	Initial commit
- .gitattributes +1 -0
- .gitignore +28 -0
- LICENSE +21 -0
- app.py +938 -0
- example.py +67 -0
- examples/cartoon_horse.mp4 +3 -0
- examples/gradio_examples/house.mp4 +3 -0
- examples/gradio_examples/man_walking_long.mp4 +3 -0
- examples/gradio_examples/parkour.mp4 +3 -0
- examples/gradio_examples/valley.mp4 +3 -0
- examples/parkour_long.mp4 +3 -0
- examples/skating.mp4 +3 -0
- examples/skiing.mp4 +3 -0
- pi3/models/dinov2/__init__.py +6 -0
- pi3/models/dinov2/hub/__init__.py +4 -0
- pi3/models/dinov2/hub/backbones.py +156 -0
- pi3/models/dinov2/hub/utils.py +39 -0
- pi3/models/dinov2/layers/__init__.py +11 -0
- pi3/models/dinov2/layers/attention.py +89 -0
- pi3/models/dinov2/layers/block.py +259 -0
- pi3/models/dinov2/layers/dino_head.py +58 -0
- pi3/models/dinov2/layers/drop_path.py +34 -0
- pi3/models/dinov2/layers/layer_scale.py +27 -0
- pi3/models/dinov2/layers/mlp.py +40 -0
- pi3/models/dinov2/layers/patch_embed.py +88 -0
- pi3/models/dinov2/layers/swiglu_ffn.py +72 -0
- pi3/models/dinov2/models/__init__.py +43 -0
- pi3/models/dinov2/models/vision_transformer.py +404 -0
- pi3/models/dinov2/utils/__init__.py +4 -0
- pi3/models/dinov2/utils/cluster.py +95 -0
- pi3/models/dinov2/utils/config.py +72 -0
- pi3/models/dinov2/utils/dtype.py +37 -0
- pi3/models/dinov2/utils/param_groups.py +103 -0
- pi3/models/dinov2/utils/utils.py +95 -0
- pi3/models/layers/attention.py +369 -0
- pi3/models/layers/block.py +406 -0
- pi3/models/layers/camera_head.py +93 -0
- pi3/models/layers/pos_embed.py +174 -0
- pi3/models/layers/transformer_head.py +81 -0
- pi3/models/pi3.py +216 -0
- pi3/utils/basic.py +223 -0
- pi3/utils/debug.py +63 -0
- pi3/utils/geometry.py +375 -0
- requirements.txt +13 -0
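
For orientation, the sketch below strings together the core inference path that the new app.py (shown further down) wraps in a Gradio UI: load the Pi3 checkpoint, load frames as a tensor, and run one forward pass. It is a minimal illustration only; it assumes the checkpoint has already been placed at ckpts/pi3.pt, as app.py expects, and uses a hypothetical frame folder my_frames/.

```python
# Minimal sketch of the inference path wrapped by app.py below.
# Assumptions: checkpoint at ckpts/pi3.pt (as in app.py) and a hypothetical
# folder of frames my_frames/; a CUDA GPU is used when available.
import torch

from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Pi3()
model.load_state_dict(torch.load("ckpts/pi3.pt", weights_only=False, map_location=device))
model = model.to(device).eval()

# Load frames as an (N, 3, H, W) tensor, exactly as run_model() does in app.py.
imgs = load_images_as_tensor("my_frames/", interval=1).to(device)

with torch.no_grad():
    with torch.amp.autocast("cuda", dtype=torch.float16):
        predictions = model(imgs[None])  # add a batch dimension

# predictions holds per-pixel world points, confidences and camera poses,
# which app.py post-processes and exports as a GLB point cloud.
```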
    	
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
*.pyc
*.egg-info
*.pt
*.pth
*.zip
*.tar.gz
*.render*

**/__pycache__/

.vscode
activate.sh
logs
outputs
data
ckpts
vis

*.ply
*.npy
*.png
*.log
*.jpg
*.out
*.pt

resize_demo_imgs.py
img_dir_to_video.py
    	
LICENSE ADDED
MIT License

Copyright (c) 2025 The Pi3 Authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
    	
app.py ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time

from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge
from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor

import trimesh
import matplotlib
from scipy.spatial.transform import Rotation


"""
Gradio utils
"""

def predictions_to_glb(
    predictions,
    conf_thres=50.0,
    filter_by_frames="all",
    show_cam=True,
) -> trimesh.Scene:
    """
    Converts Pi3 predictions to a 3D scene represented as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - points: 3D point coordinates (S, H, W, 3)
            - conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3)
            - camera_poses: Camera-to-world poses (S, 4, 4)
        conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0)
        filter_by_frames (str): Frame filter specification (default: "all")
        show_cam (bool): Include camera visualization (default: True)

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    if conf_thres is None:
        conf_thres = 10

    print("Building GLB scene")
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            # Extract the index part before the colon
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass

    pred_world_points = predictions["points"]
    pred_world_points_conf = predictions.get("conf", np.ones_like(pred_world_points[..., 0]))

    # Get images from predictions
    images = predictions["images"]
    # Use extrinsic matrices instead of pred_extrinsic_list
    camera_poses = predictions["camera_poses"]

    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_poses = camera_poses[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats - check if images need transposing
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:  # Assume already in NHWC format
        colors_rgb = images
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    conf = pred_world_points_conf.reshape(-1)
    # Convert percentage threshold to actual confidence value
    if conf_thres == 0.0:
        conf_threshold = 0.0
    else:
        # conf_threshold = np.percentile(conf, conf_thres)
        conf_threshold = conf_thres / 100

    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Calculate the 5th and 95th percentiles along each axis
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)

        # Calculate the diagonal length of the percentile bounding box
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)

    scene_3d.add_geometry(point_cloud_data)

    # Prepare 4x4 matrices for camera extrinsics
    num_cameras = len(camera_poses)

    if show_cam:
        # Add camera models to the scene
        for i in range(num_cameras):
            camera_to_world = camera_poses[i]
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])

            # integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, 1.)  # fixed camera size

    # Rotate scene for better visualization
    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler("y", 100, degrees=True).as_matrix()  # plane rotate
    align_rotation[:3, :3] = align_rotation[:3, :3] @ Rotation.from_euler("x", 155, degrees=True).as_matrix()  # roll
    scene_3d.apply_transform(align_rotation)

    print("GLB Scene built")
    return scene_3d

def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float):
    """
    Integrates a fake camera mesh into the 3D scene.

    Args:
        scene (trimesh.Scene): The 3D scene to add the camera model.
        transform (np.ndarray): Transformation matrix for camera positioning.
        face_colors (tuple): Color of the camera face.
        scene_scale (float): Scale of the scene.
    """

    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Create cone shape for camera
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    # Combine transformations
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Generate mesh for the camera
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    vertices_combined = np.concatenate(
        [
            camera_cone_shape.vertices,
            0.95 * camera_cone_shape.vertices,
            transform_points(slight_rotation, camera_cone_shape.vertices),
        ]
    )
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    mesh_faces = compute_camera_faces(camera_cone_shape)

    # Add the camera mesh to the scene
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)


def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Constructs and returns the OpenGL conversion matrix.

    Returns:
        numpy.ndarray: A 4x4 OpenGL conversion matrix.
    """
    # Create an identity matrix
    matrix = np.identity(4)

    # Flip the y and z axes
    matrix[1, 1] = -1
    matrix[2, 2] = -1

    return matrix


def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    Args:
        transformation (np.ndarray): Transformation matrix.
        points (np.ndarray): Points to be transformed.
        dim (int, optional): Dimension for reshaping the result.

    Returns:
        np.ndarray: Transformed points.
    """
    points = np.asarray(points)
    initial_shape = points.shape[:-1]
    dim = dim or points.shape[-1]

    # Apply transformation
    transformation = transformation.swapaxes(-1, -2)  # Transpose the transformation matrix
    points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]

    # Reshape the result
    result = points[..., :dim].reshape(*initial_shape, dim)
    return result


def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Computes the faces for the camera mesh.

    Args:
        cone_shape (trimesh.Trimesh): The shape of the camera cone.

    Returns:
        np.ndarray: Array of faces for the camera mesh.
    """
    # Create pseudo cameras
    faces_list = []
    num_vertices_cone = len(cone_shape.vertices)

    for face in cone_shape.faces:
        if 0 in face:
            continue
        v1, v2, v3 = face
        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

        faces_list.extend(
            [
                (v1, v2, v2_offset),
                (v1, v1_offset, v3),
                (v3_offset, v2, v3),
                (v1, v2, v2_offset_2),
                (v1, v1_offset_2, v3),
                (v3_offset_2, v2, v3),
            ]
        )

    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)

# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(target_dir, model) -> dict:
    print(f"Processing images from {target_dir}")

    # Device check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    # Move model to device
    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    # interval = 10 if target_dir.endswith('.mp4') else 1
    interval = 1
    imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=interval).to(device)  # (N, 3, H, W)

    # Run inference
    print("Running model inference...")
    dtype = torch.float16
    with torch.no_grad():
        with torch.amp.autocast('cuda', dtype=dtype):
            predictions = model(imgs[None])  # Add batch dimension
    predictions['images'] = imgs[None].permute(0, 1, 3, 4, 2)
    predictions['conf'] = torch.sigmoid(predictions['conf'])
    edge = depth_edge(predictions['local_points'][..., 2], rtol=0.03)
    predictions['conf'][edge] = 0.0
    del predictions['local_points']

    # # transform to first camera coordinate
    # predictions['points'] = torch.einsum('bij, bnhwj -> bnhwi', se3_inverse(predictions['camera_poses'][:, 0]), homogenize_points(predictions['points']))[..., :3]
    # predictions['camera_poses'] = torch.einsum('bij, bnjk -> bnik', se3_inverse(predictions['camera_poses'][:, 0]), predictions['camera_poses'])

    # Convert tensors to numpy
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension

    # Clean up
    torch.cuda.empty_cache()
    return predictions

# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images, interval=-1):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)
    os.makedirs(target_dir_images, exist_ok=True)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        if interval is not None and interval > 0:
            input_images = input_images[::interval]

        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        if interval is not None and interval > 0:
            frame_interval = interval
        else:
            frame_interval = int(fps * 1)  # 1 frame/sec

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images, interval=-1):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery. Returns (None, target_dir, image_paths, status message).
    If nothing is uploaded, returns None for all outputs.
    """
    if not input_video and not input_images:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images, interval=interval)
    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."

# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    show_cam=True,
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return None, "No valid target directory found. Please upload first.", None, None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Clears the 3D viewer, the stored target_dir, and empties the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."


def update_visualization(
    target_dir, conf_thres, frame_filter, show_cam, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.
    """

    # If it's an example click, skip as requested
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "images",
        "points",
        "conf",
        "camera_poses",
    ]

    loaded = np.load(predictions_path)
    predictions = {key: np.array(loaded[key]) for key in key_list}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"

| 542 | 
            +
            # -------------------------------------------------------------------------
         | 
| 543 | 
            +
            # Example images
         | 
| 544 | 
            +
            # -------------------------------------------------------------------------
         | 
| 545 | 
            +
             | 
| 546 | 
            +
            house = "examples/gradio_examples/house.mp4"
         | 
| 547 | 
            +
            man_walking_long = "examples/gradio_examples/man_walking_long.mp4"
         | 
| 548 | 
            +
            parkour = "examples/gradio_examples/parkour.mp4"
         | 
| 549 | 
            +
            valley = "examples/gradio_examples/valley.mp4"
         | 
| 550 | 
            +
            cartoon_horse = "examples/cartoon_horse.mp4"
         | 
| 551 | 
            +
            parkour_long = "examples/parkour_long.mp4"
         | 
| 552 | 
            +
            skating = "examples/skating.mp4"
         | 
| 553 | 
            +
            skiing = "examples/skiing.mp4"
         | 
| 554 | 
            +
             | 
| 555 | 
            +
            # -------------------------------------------------------------------------
         | 
| 556 | 
            +
            # 6) Build Gradio UI
         | 
| 557 | 
            +
            # -------------------------------------------------------------------------
         | 
| 558 | 
            +
             | 
| 559 | 
            +
            if __name__ == '__main__':
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                print("Initializing and loading Pi3 model...")
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                model = Pi3()
         | 
| 566 | 
            +
                # _URL = "https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors"
         | 
| 567 | 
            +
                # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
         | 
| 568 | 
            +
                model.load_state_dict(torch.load('ckpts/pi3.pt', weights_only=False, map_location=device))
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                model.eval()
         | 
| 571 | 
            +
                model = model.to(device)
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                theme = gr.themes.Ocean()
         | 
| 574 | 
            +
                theme.set(
         | 
| 575 | 
            +
                    checkbox_label_background_fill_selected="*button_primary_background_fill",
         | 
| 576 | 
            +
                    checkbox_label_text_color_selected="*button_primary_text_color",
         | 
| 577 | 
            +
                )
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                with gr.Blocks(
         | 
| 580 | 
            +
                    theme=theme,
         | 
| 581 | 
            +
                    css="""
         | 
| 582 | 
            +
                    /* --- Google Fonts import (techy display fonts) --- */
         | 
| 583 | 
            +
                    @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Rajdhani:wght@400;500;700&display=swap');
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                    /* --- Animation keyframes --- */
         | 
| 586 | 
            +
                    /* Animated nebula background effect */
         | 
| 587 | 
            +
                    @keyframes gradient-animation {
         | 
| 588 | 
            +
                        0% { background-position: 0% 50%; }
         | 
| 589 | 
            +
                        50% { background-position: 100% 50%; }
         | 
| 590 | 
            +
                        100% { background-position: 0% 50%; }
         | 
| 591 | 
            +
                    }
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                    /* Neon glow for the title and status text */
         | 
| 594 | 
            +
                    @keyframes text-glow {
         | 
| 595 | 
            +
                        0%, 100% {
         | 
| 596 | 
            +
                            text-shadow: 0 0 10px #0ea5e9, 0 0 20px #0ea5e9, 0 0 30px #4f46e5, 0 0 40px #4f46e5;
         | 
| 597 | 
            +
                        }
         | 
| 598 | 
            +
                        50% {
         | 
| 599 | 
            +
                            text-shadow: 0 0 5px #0ea5e9, 0 0 10px #0ea5e9, 0 0 15px #4f46e5, 0 0 20px #4f46e5;
         | 
| 600 | 
            +
                        }
         | 
| 601 | 
            +
                    }
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                    /* Breathing glow for card borders */
         | 
| 604 | 
            +
                    @keyframes border-glow {
         | 
| 605 | 
            +
                        0% { border-color: rgba(79, 70, 229, 0.5); box-shadow: 0 0 15px rgba(79, 70, 229, 0.3); }
         | 
| 606 | 
            +
                        50% { border-color: rgba(14, 165, 233, 0.8); box-shadow: 0 0 25px rgba(14, 165, 233, 0.5); }
         | 
| 607 | 
            +
                        100% { border-color: rgba(79, 70, 229, 0.5); box-shadow: 0 0 15px rgba(79, 70, 229, 0.3); }
         | 
| 608 | 
            +
                    }
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                    /* --- Global styles: cosmic dark theme --- */
         | 
| 611 | 
            +
                    .gradio-container {
         | 
| 612 | 
            +
                        font-family: 'Rajdhani', sans-serif;
         | 
| 613 | 
            +
                        background: linear-gradient(-45deg, #020617, #111827, #082f49, #4f46e5);
         | 
| 614 | 
            +
                        background-size: 400% 400%;
         | 
| 615 | 
            +
                        animation: gradient-animation 20s ease infinite;
         | 
| 616 | 
            +
                        color: #9ca3af;
         | 
| 617 | 
            +
                    }
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                    /* --- Global text color fix (works around the light-mode issue) --- */
         | 
| 620 | 
            +
                    
         | 
| 621 | 
            +
                    /* 1. Fix text color globally and inside labels and inputs */
         | 
| 622 | 
            +
                    .gradio-container, .gr-label label, .gr-input, input, textarea, .gr-check-radio label {
         | 
| 623 | 
            +
                        color: #d1d5db !important; /* soft light gray */
         | 
| 624 | 
            +
                    }
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                    /* 2. Fix the Examples table header */
         | 
| 627 | 
            +
                    thead th {
         | 
| 628 | 
            +
                        color: white !important;
         | 
| 629 | 
            +
                        background-color: #1f2937 !important; /* also give the header a background for better contrast */
         | 
| 630 | 
            +
                    }
         | 
| 631 | 
            +
             | 
| 632 | 
            +
                    /* 3. Fix the Examples table body text */
         | 
| 633 | 
            +
                    tbody td {
         | 
| 634 | 
            +
                        color: #d1d5db !important;
         | 
| 635 | 
            +
                    }
         | 
| 636 | 
            +
                    
         | 
| 637 | 
            +
                    /* --- Status message & output title styles (custom-log) ✨ --- */
         | 
| 638 | 
            +
                    .custom-log * {
         | 
| 639 | 
            +
                        font-family: 'Orbitron', sans-serif;
         | 
| 640 | 
            +
                        font-size: 24px !important;
         | 
| 641 | 
            +
                        font-weight: 700 !important;
         | 
| 642 | 
            +
                        text-align: center !important;
         | 
| 643 | 
            +
                        color: transparent !important;
         | 
| 644 | 
            +
                        background-image: linear-gradient(120deg, #93c5fd, #6ee7b7, #fde047);
         | 
| 645 | 
            +
                        background-size: 300% 300%;
         | 
| 646 | 
            +
                        -webkit-background-clip: text;
         | 
| 647 | 
            +
                        background-clip: text;
         | 
| 648 | 
            +
                        animation: gradient-animation 8s ease-in-out infinite, text-glow 3s ease-in-out infinite;
         | 
| 649 | 
            +
                        padding: 10px 0;
         | 
| 650 | 
            +
                    }
         | 
| 651 | 
            +
                    
         | 
| 652 | 
            +
                    /* --- UI card/group styles (glassmorphism) 💎 --- */
         | 
| 653 | 
            +
                    .gr-block.gr-group {
         | 
| 654 | 
            +
                        background-color: rgba(17, 24, 39, 0.6);
         | 
| 655 | 
            +
                        backdrop-filter: blur(10px);
         | 
| 656 | 
            +
                        -webkit-backdrop-filter: blur(10px);
         | 
| 657 | 
            +
                        border: 1px solid rgba(55, 65, 81, 0.5);
         | 
| 658 | 
            +
                        border-radius: 16px;
         | 
| 659 | 
            +
                        box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.37);
         | 
| 660 | 
            +
                        transition: all 0.3s ease;
         | 
| 661 | 
            +
                        /* apply the breathing border-glow animation */
         | 
| 662 | 
            +
                        animation: border-glow 5s infinite alternate;
         | 
| 663 | 
            +
                    }
         | 
| 664 | 
            +
                    .gr-block.gr-group:hover {
         | 
| 665 | 
            +
                        box-shadow: 0 0 25px rgba(14, 165, 233, 0.4);
         | 
| 666 | 
            +
                        border-color: rgba(14, 165, 233, 0.6);
         | 
| 667 | 
            +
                    }
         | 
| 668 | 
            +
                    
         | 
| 669 | 
            +
                    /* --- Eye-catching button styles 🚀 --- */
         | 
| 670 | 
            +
                    .gr-button {
         | 
| 671 | 
            +
                        background: linear-gradient(to right, #4f46e5, #7c3aed, #0ea5e9) !important;
         | 
| 672 | 
            +
                        background-size: 200% auto !important;
         | 
| 673 | 
            +
                        color: white !important;
         | 
| 674 | 
            +
                        font-weight: bold !important;
         | 
| 675 | 
            +
                        border: none !important;
         | 
| 676 | 
            +
                        border-radius: 10px !important;
         | 
| 677 | 
            +
                        box-shadow: 0 4px 15px 0 rgba(79, 70, 229, 0.5) !important;
         | 
| 678 | 
            +
                        transition: all 0.4s ease-in-out !important;
         | 
| 679 | 
            +
                        font-family: 'Orbitron', sans-serif !important;
         | 
| 680 | 
            +
                        text-transform: uppercase;
         | 
| 681 | 
            +
                        letter-spacing: 1px;
         | 
| 682 | 
            +
                    }
         | 
| 683 | 
            +
                    .gr-button:hover {
         | 
| 684 | 
            +
                        background-position: right center !important;
         | 
| 685 | 
            +
                        box-shadow: 0 4px 20px 0 rgba(14, 165, 233, 0.6) !important;
         | 
| 686 | 
            +
                        transform: translateY(-3px) scale(1.02);
         | 
| 687 | 
            +
                    }
         | 
| 688 | 
            +
                    .gr-button.primary {
         | 
| 689 | 
            +
                        /* primary button gets the breathing-glow animation */
         | 
| 690 | 
            +
                        animation: border-glow 3s infinite alternate;
         | 
| 691 | 
            +
                    }
         | 
| 692 | 
            +
                    """,
         | 
| 693 | 
            +
                ) as demo:
         | 
| 694 | 
            +
                    # Instead of gr.State, we use a hidden Textbox:
         | 
| 695 | 
            +
                    is_example = gr.Textbox(label="is_example", visible=False, value="None")
         | 
| 696 | 
            +
                    num_images = gr.Textbox(label="num_images", visible=False, value="None")
         | 
| 697 | 
            +
                    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
         | 
| 698 | 
            +
             | 
| 699 | 
            +
                    gr.HTML(
         | 
| 700 | 
            +
                    """
         | 
| 701 | 
            +
                    <style>
         | 
| 702 | 
            +
                            /* --- Styles specific to the intro text block --- */
         | 
| 703 | 
            +
                            .intro-content { font-size: 17px !important; line-height: 1.7; color: #C0C0C0 !important; }
         | 
| 704 | 
            +
                            /* extra rule for p tags to make sure the override applies */
         | 
| 705 | 
            +
                            .intro-content p { color: #C0C0C0 !important; }
         | 
| 706 | 
            +
                            
         | 
| 707 | 
            +
                            .intro-content h1 {
         | 
| 708 | 
            +
                                font-family: 'Orbitron', sans-serif; font-size: 2.8em !important; font-weight: 900;
         | 
| 709 | 
            +
                                text-align: center; color: #C0C0C0 !important; animation: text-glow 4s ease-in-out infinite; margin-bottom: 0px;
         | 
| 710 | 
            +
                            }
         | 
| 711 | 
            +
                            .intro-content .pi-symbol {
         | 
| 712 | 
            +
                                display: inline-block; color: transparent;
         | 
| 713 | 
            +
                                background-image: linear-gradient(120deg, #38bdf8, #818cf8, #c084fc);
         | 
| 714 | 
            +
                                -webkit-background-clip: text; background-clip: text;
         | 
| 715 | 
            +
                                text-shadow: 0 0 15px rgba(129, 140, 248, 0.5);
         | 
| 716 | 
            +
                            }
         | 
| 717 | 
            +
                            .intro-content .subtitle { text-align: center; font-size: 1.1em; margin-bottom: 2rem; }
         | 
| 718 | 
            +
                            .intro-content a.themed-link {
         | 
| 719 | 
            +
                                color: #C0C0C0 !important; text-decoration: none; font-weight: 700; transition: all 0.3s ease;
         | 
| 720 | 
            +
                            }
         | 
| 721 | 
            +
                            .intro-content a.themed-link:hover { color: #EAEAEA !important; text-shadow: 0 0 8px rgba(234, 234, 234, 0.7); }
         | 
| 722 | 
            +
                            .intro-content h3 {
         | 
| 723 | 
            +
                                font-family: 'Orbitron', sans-serif; color: #C0C0C0 !important; text-transform: uppercase;
         | 
| 724 | 
            +
                                letter-spacing: 2px; border-bottom: 1px solid #374151; padding-bottom: 8px; margin-top: 25px;
         | 
| 725 | 
            +
                            }
         | 
| 726 | 
            +
                            .intro-content ol { list-style: none; padding-left: 0; counter-reset: step-counter; }
         | 
| 727 | 
            +
                            .intro-content ol li {
         | 
| 728 | 
            +
                                counter-increment: step-counter; margin-bottom: 15px; padding-left: 45px; position: relative;
         | 
| 729 | 
            +
                                color: #C0C0C0 !important; /* make sure list-item text is silver too */
         | 
| 730 | 
            +
                            }
         | 
| 731 | 
            +
                            /* custom styled list numbers */
         | 
| 732 | 
            +
                            .intro-content ol li::before {
         | 
| 733 | 
            +
                                content: counter(step-counter); position: absolute; left: 0; top: 0;
         | 
| 734 | 
            +
                                width: 30px; height: 30px; background: linear-gradient(135deg, #1e3a8a, #4f46e5);
         | 
| 735 | 
            +
                                border-radius: 50%; color: white; font-weight: 700; font-family: 'Orbitron', sans-serif;
         | 
| 736 | 
            +
                                display: flex; align-items: center; justify-content: center;
         | 
| 737 | 
            +
                                box-shadow: 0 0 10px rgba(79, 70, 229, 0.5);
         | 
| 738 | 
            +
                            }
         | 
| 739 | 
            +
                            .intro-content strong { color: #C0C0C0 !important; font-weight: 700; }
         | 
| 740 | 
            +
                            .intro-content .performance-note {
         | 
| 741 | 
            +
                                background-color: rgba(14, 165, 233, 0.1); border-left: 4px solid #0ea5e9;
         | 
| 742 | 
            +
                                padding: 15px; border-radius: 8px; margin-top: 20px;
         | 
| 743 | 
            +
                            }
         | 
| 744 | 
            +
                            /* make sure text inside the note box is covered as well */
         | 
| 745 | 
            +
                            .intro-content .performance-note p { color: #C0C0C0 !important; }
         | 
| 746 | 
            +
             | 
| 747 | 
            +
                    </style>
         | 
| 748 | 
            +
                            
         | 
| 749 | 
            +
                    <div class="intro-content">
         | 
| 750 | 
            +
                        <h1>🌌 <span class="pi-symbol">π³</span>: Scalable Permutation-Equivariant Visual Geometry Learning</h1>
         | 
| 751 | 
            +
                        <p class="subtitle">
         | 
| 752 | 
            +
                            <a class="themed-link" href="">🐙 GitHub Repository</a> |
         | 
| 753 | 
            +
                            <a class="themed-link" href="#">🚀 Project Page</a>
         | 
| 754 | 
            +
                        </p>
         | 
| 755 | 
            +
                        
         | 
| 756 | 
            +
                        <p>Transform your videos or image collections into detailed 3D models. The <strong class="pi-symbol">π³</strong> model processes your visual data to generate a rich 3D point cloud and calculate the corresponding camera perspectives.</p>
         | 
| 757 | 
            +
                        
         | 
| 758 | 
            +
                        <h3>How to Use:</h3>
         | 
| 759 | 
            +
                        <ol>
         | 
| 760 | 
            +
                            <li><strong>Provide Your Media:</strong> Upload a video or image set. You can specify a sampling interval below. By default, videos are sampled at 1 frame per second, and for image sets, every image is used (interval of 1). Your inputs will be displayed in the "Preview" gallery.</li>
         | 
| 761 | 
            +
                            <li><strong>Generate the 3D Model:</strong> Press the "Reconstruct" button to initiate the process.</li>
         | 
| 762 | 
            +
                            <li><strong>Explore and Refine Your Model:</strong> The generated 3D model will appear in the viewer on the right. Interact with it by rotating, panning, and zooming. You can also download the model as a GLB file. For further refinement, use the options below the viewer to adjust point confidence, filter by frame, or toggle camera visibility.</li>
         | 
| 763 | 
            +
                        </ol>
         | 
| 764 | 
            +
                        
         | 
| 765 | 
            +
                        <div class="performance-note">
         | 
| 766 | 
            +
                            <p><strong>A Quick Note on Performance:</strong> The core processing by <strong class="pi-symbol">π³</strong> is incredibly fast, typically finishing in under a second. However, rendering the final 3D point cloud can take longer, depending on the complexity of the scene and the capabilities of the rendering engine.</p>
         | 
| 767 | 
            +
                        </div>
         | 
| 768 | 
            +
                    </div>
         | 
| 769 | 
            +
                    """
         | 
| 770 | 
            +
                )
         | 
| 771 | 
            +
             | 
| 772 | 
            +
                    with gr.Row():
         | 
| 773 | 
            +
                        with gr.Column(scale=1):
         | 
| 774 | 
            +
                            with gr.Group():
         | 
| 775 | 
            +
                                gr.Markdown("### 1. Upload Media")
         | 
| 776 | 
            +
                                input_video = gr.Video(label="Upload Video", interactive=True)
         | 
| 777 | 
            +
                                input_images = gr.File(file_count="multiple", label="Or Upload Images", interactive=True)
         | 
| 778 | 
            +
                                interval = gr.Number(None, label='Frame/Image Interval', info="Sampling interval. Video default: 1 FPS. Image default: 1 (all images).")
         | 
| 779 | 
            +
                            
         | 
| 780 | 
            +
                            image_gallery = gr.Gallery(
         | 
| 781 | 
            +
                                label="Image Preview",
         | 
| 782 | 
            +
                                columns=4,
         | 
| 783 | 
            +
                                height="300px",
         | 
| 784 | 
            +
                                show_download_button=True,
         | 
| 785 | 
            +
                                object_fit="contain",
         | 
| 786 | 
            +
                                preview=True,
         | 
| 787 | 
            +
                            )
         | 
| 788 | 
            +
             | 
| 789 | 
            +
                        with gr.Column(scale=2):
         | 
| 790 | 
            +
                            gr.Markdown("### 2. View Reconstruction")
         | 
| 791 | 
            +
                            log_output = gr.Markdown("Please upload media and click Reconstruct.", elem_classes=["custom-log"])
         | 
| 792 | 
            +
                            reconstruction_output = gr.Model3D(height=480, zoom_speed=0.5, pan_speed=0.5, label="3D Output")
         | 
| 793 | 
            +
                            
         | 
| 794 | 
            +
                            with gr.Row():
         | 
| 795 | 
            +
                                submit_btn = gr.Button("Reconstruct", scale=3, variant="primary")
         | 
| 796 | 
            +
                                clear_btn = gr.ClearButton(
         | 
| 797 | 
            +
                                    scale=1
         | 
| 798 | 
            +
                                )
         | 
| 799 | 
            +
                            
         | 
| 800 | 
            +
                            with gr.Group():
         | 
| 801 | 
            +
                                gr.Markdown("### 3. Adjust Visualization")
         | 
| 802 | 
            +
                                with gr.Row():
         | 
| 803 | 
            +
                                    conf_thres = gr.Slider(minimum=0, maximum=100, value=20, step=0.1, label="Confidence Threshold (%)")
         | 
| 804 | 
            +
                                    show_cam = gr.Checkbox(label="Show Cameras", value=True)
         | 
| 805 | 
            +
                                frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
         | 
| 806 | 
            +
             | 
| 807 | 
            +
                    # Set clear button targets
         | 
| 808 | 
            +
                    clear_btn.add([input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, interval])
         | 
| 809 | 
            +
             | 
| 810 | 
            +
                    # ---------------------- Examples section ----------------------
         | 
| 811 | 
            +
                    examples = [
         | 
| 812 | 
            +
                        [skating, None, 10, 20, True],
         | 
| 813 | 
            +
                        [parkour_long, None, 20, 10, True],
         | 
| 814 | 
            +
                        [cartoon_horse, None, 10, 20, True],
         | 
| 815 | 
            +
                        [skiing, None, 30, 70, True],
         | 
| 816 | 
            +
                        [man_walking_long, None, 1, 50, True],
         | 
| 817 | 
            +
                        [house, None, 1, 20, True],
         | 
| 818 | 
            +
                        [parkour, None, 1, 20, True],
         | 
| 819 | 
            +
                        [valley, None, 1, 20, True],
         | 
| 820 | 
            +
                    ]
         | 
| 821 | 
            +
             | 
| 822 | 
            +
                    def example_pipeline(
         | 
| 823 | 
            +
                        input_video,
         | 
| 824 | 
            +
                        input_images,
         | 
| 825 | 
            +
                        interval,
         | 
| 826 | 
            +
                        conf_thres,
         | 
| 827 | 
            +
                        show_cam,
         | 
| 828 | 
            +
                    ):
         | 
| 829 | 
            +
                        """
         | 
| 830 | 
            +
                        1) Copy example images to new target_dir
         | 
| 831 | 
            +
                        2) Reconstruct
         | 
| 832 | 
            +
                        3) Return model3D + logs + new_dir + updated dropdown + gallery
         | 
| 833 | 
            +
                        We do NOT return is_example. It's just an input.
         | 
| 834 | 
            +
                        """
         | 
| 835 | 
            +
                        target_dir, image_paths = handle_uploads(input_video, input_images, interval)
         | 
| 836 | 
            +
                        # Always use "All" for frame_filter in examples
         | 
| 837 | 
            +
                        frame_filter = "All"
         | 
| 838 | 
            +
                        glbfile, log_msg, dropdown = gradio_demo(
         | 
| 839 | 
            +
                            target_dir, conf_thres, frame_filter, show_cam
         | 
| 840 | 
            +
                        )
         | 
| 841 | 
            +
                        return glbfile, log_msg, target_dir, dropdown, image_paths
         | 
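For reference, each example row above maps positionally onto this function's arguments. A hedged sketch of calling it headlessly with the first row (`skating.mp4`); the return values come from `handle_uploads` and `gradio_demo`, which are defined earlier in app.py:

```python
# Row [skating, None, 10, 20, True] -> (input_video, input_images, interval,
# conf_thres, show_cam); frame_filter is forced to "All" inside the function.
glb, log_msg, target_dir, dropdown, gallery_paths = example_pipeline(
    input_video="examples/skating.mp4",
    input_images=None,   # video example, so no separate image upload
    interval=10,         # sample every 10th frame
    conf_thres=20,       # confidence threshold in percent
    show_cam=True,       # draw cameras in the exported GLB
)
```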
| 842 | 
            +
             | 
| 843 | 
            +
                    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    gr.Examples(
         | 
| 846 | 
            +
                        examples=examples,
         | 
| 847 | 
            +
                        inputs=[
         | 
| 848 | 
            +
                            input_video,
         | 
| 849 | 
            +
                            input_images,
         | 
| 850 | 
            +
                            interval,
         | 
| 851 | 
            +
                            conf_thres,
         | 
| 852 | 
            +
                            show_cam,
         | 
| 853 | 
            +
                        ],
         | 
| 854 | 
            +
                        outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
         | 
| 855 | 
            +
                        fn=example_pipeline,
         | 
| 856 | 
            +
                        cache_examples=False,
         | 
| 857 | 
            +
                        examples_per_page=50,
         | 
| 858 | 
            +
                        run_on_click=False,
         | 
| 859 | 
            +
                    )
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 862 | 
            +
                    # "Reconstruct" button logic:
         | 
| 863 | 
            +
                    #  - Clear fields
         | 
| 864 | 
            +
                    #  - Update log
         | 
| 865 | 
            +
                    #  - gradio_demo(...) with the existing target_dir
         | 
| 866 | 
            +
                    #  - Then set is_example = "False"
         | 
| 867 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 868 | 
            +
                    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
         | 
| 869 | 
            +
                        fn=update_log, inputs=[], outputs=[log_output]
         | 
| 870 | 
            +
                    ).then(
         | 
| 871 | 
            +
                        fn=gradio_demo,
         | 
| 872 | 
            +
                        inputs=[
         | 
| 873 | 
            +
                            target_dir_output,
         | 
| 874 | 
            +
                            conf_thres,
         | 
| 875 | 
            +
                            frame_filter,
         | 
| 876 | 
            +
                            show_cam,
         | 
| 877 | 
            +
                        ],
         | 
| 878 | 
            +
                        outputs=[reconstruction_output, log_output, frame_filter],
         | 
| 879 | 
            +
                    ).then(
         | 
| 880 | 
            +
                        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
         | 
| 881 | 
            +
                    )
         | 
| 882 | 
            +
             | 
| 883 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 884 | 
            +
                    # Real-time Visualization Updates
         | 
| 885 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 886 | 
            +
                    conf_thres.change(
         | 
| 887 | 
            +
                        update_visualization,
         | 
| 888 | 
            +
                        [
         | 
| 889 | 
            +
                            target_dir_output,
         | 
| 890 | 
            +
                            conf_thres,
         | 
| 891 | 
            +
                            frame_filter,
         | 
| 892 | 
            +
                            show_cam,
         | 
| 893 | 
            +
                            is_example,
         | 
| 894 | 
            +
                        ],
         | 
| 895 | 
            +
                        [reconstruction_output, log_output],
         | 
| 896 | 
            +
                    )
         | 
| 897 | 
            +
                    frame_filter.change(
         | 
| 898 | 
            +
                        update_visualization,
         | 
| 899 | 
            +
                        [
         | 
| 900 | 
            +
                            target_dir_output,
         | 
| 901 | 
            +
                            conf_thres,
         | 
| 902 | 
            +
                            frame_filter,
         | 
| 903 | 
            +
                            show_cam,
         | 
| 904 | 
            +
                            is_example,
         | 
| 905 | 
            +
                        ],
         | 
| 906 | 
            +
                        [reconstruction_output, log_output],
         | 
| 907 | 
            +
                    )
         | 
| 908 | 
            +
                
         | 
| 909 | 
            +
                    show_cam.change(
         | 
| 910 | 
            +
                        update_visualization,
         | 
| 911 | 
            +
                        [
         | 
| 912 | 
            +
                            target_dir_output,
         | 
| 913 | 
            +
                            conf_thres,
         | 
| 914 | 
            +
                            frame_filter,
         | 
| 915 | 
            +
                            show_cam,
         | 
| 916 | 
            +
                            is_example,
         | 
| 917 | 
            +
                        ],
         | 
| 918 | 
            +
                        [reconstruction_output, log_output],
         | 
| 919 | 
            +
                    )
         | 
| 920 | 
            +
             | 
| 921 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 922 | 
            +
                    # Auto-update gallery whenever user uploads or changes their files
         | 
| 923 | 
            +
                    # -------------------------------------------------------------------------
         | 
| 924 | 
            +
                    input_video.change(
         | 
| 925 | 
            +
                        fn=update_gallery_on_upload,
         | 
| 926 | 
            +
                        inputs=[input_video, input_images, interval],
         | 
| 927 | 
            +
                        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
         | 
| 928 | 
            +
                    )
         | 
| 929 | 
            +
                    input_images.change(
         | 
| 930 | 
            +
                        fn=update_gallery_on_upload,
         | 
| 931 | 
            +
                        inputs=[input_video, input_images, interval],
         | 
| 932 | 
            +
                        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
         | 
| 933 | 
            +
                    )
         | 
| 934 | 
            +
             | 
| 935 | 
            +
                demo.queue(max_size=20).launch(show_error=True, share=True, server_port=10001, server_name='0.0.0.0')
         | 
| 936 | 
            +
             | 
| 937 | 
            +
             | 
| 938 | 
            +
             | 
    	
        example.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            from pi3.utils.basic import load_images_as_tensor, write_ply
         | 
| 4 | 
            +
            from pi3.utils.geometry import depth_edge
         | 
| 5 | 
            +
            from pi3.models.pi3 import Pi3
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            if __name__ == '__main__':
         | 
| 8 | 
            +
                # --- Argument Parsing ---
         | 
| 9 | 
            +
                parser = argparse.ArgumentParser(description="Run inference with the Pi3 model.")
         | 
| 10 | 
            +
                
         | 
| 11 | 
            +
                parser.add_argument("--data_path", type=str, default='examples/parkour',
         | 
| 12 | 
            +
                                    help="Path to the input image directory or a video file.")
         | 
| 13 | 
            +
                parser.add_argument("--save_path", type=str, default='examples/parkour.ply',
         | 
| 14 | 
            +
                                    help="Path to save the output .ply file.")
         | 
| 15 | 
            +
                parser.add_argument("--interval", type=int, default=-1,
         | 
| 16 | 
            +
                                    help="Interval to sample image. Default: 1 for images dir, 10 for video")
         | 
| 17 | 
            +
                parser.add_argument("--ckpt", type=str, default=None,
         | 
| 18 | 
            +
                                    help="Path to the model checkpoint file. Default: None")
         | 
| 19 | 
            +
                parser.add_argument("--device", type=str, default='cuda',
         | 
| 20 | 
            +
                                    help="Device to run inference on ('cuda' or 'cpu'). Default: 'cuda'")
         | 
| 21 | 
            +
                                    
         | 
| 22 | 
            +
                args = parser.parse_args()
         | 
| 23 | 
            +
                if args.interval < 0:
         | 
| 24 | 
            +
                    args.interval = 10 if args.data_path.endswith('.mp4') else 1
         | 
| 25 | 
            +
                print(f'Sampling interval: {args.interval}')
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                # from pi3.utils.debug import setup_debug
         | 
| 28 | 
            +
                # setup_debug()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                # 1. Prepare model
         | 
| 31 | 
            +
                print(f"Loading model...")
         | 
| 32 | 
            +
                device = torch.device(args.device)
         | 
| 33 | 
            +
                if args.ckpt is not None:
         | 
| 34 | 
            +
                    model = Pi3().to(device).eval()
         | 
| 35 | 
            +
                    if args.ckpt.endswith('.safetensors'):
         | 
| 36 | 
            +
                        from safetensors.torch import load_file
         | 
| 37 | 
            +
                        weight = load_file(args.ckpt)
         | 
| 38 | 
            +
                    else:
         | 
| 39 | 
            +
                        weight = torch.load(args.ckpt, map_location=device, weights_only=False)
         | 
| 40 | 
            +
                    
         | 
| 41 | 
            +
                    model.load_state_dict(weight)
         | 
| 42 | 
            +
                else:
         | 
| 43 | 
            +
                    model = Pi3.from_pretrained("yyfz233/Pi3").to(device).eval()
         | 
| 44 | 
            +
                    # or
         | 
| 45 | 
            +
                    # model = Pi3().to(device).eval()
         | 
| 46 | 
            +
                    # _URL = "https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors"
         | 
| 47 | 
            +
                    # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                # 2. Prepare input data
         | 
| 50 | 
            +
                # The load_images_as_tensor function will print the loading path
         | 
| 51 | 
            +
                imgs = load_images_as_tensor(args.data_path, interval=args.interval).to(device) # (N, 3, H, W)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                # 3. Infer
         | 
| 54 | 
            +
                print("Running model inference...")
         | 
| 55 | 
            +
                with torch.no_grad():
         | 
| 56 | 
            +
                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
         | 
| 57 | 
            +
                        res = model(imgs[None]) # Add batch dimension
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                # 4. process mask
         | 
| 60 | 
            +
                masks = torch.sigmoid(res['conf'][..., 0]) > 0.1
         | 
| 61 | 
            +
                non_edge = ~depth_edge(res['local_points'][..., 2], rtol=0.03)
         | 
| 62 | 
            +
                masks = torch.logical_and(masks, non_edge)[0]
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                # 5. Save points
         | 
| 65 | 
            +
                print(f"Saving point cloud to: {args.save_path}")
         | 
| 66 | 
            +
                write_ply(res['points'][0][masks].cpu(), imgs.permute(0, 2, 3, 1)[masks], args.save_path)
         | 
| 67 | 
            +
                print("Done.")
         | 
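A typical run of this script, written here as a small Python driver so the example stays in the same language as the rest of the repo (the equivalent shell command is shown in the comment; `skating.ply` is an arbitrary output path):

```python
# Hedged sketch, equivalent to:
#   python example.py --data_path examples/skating.mp4 --interval 10 --save_path skating.ply
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "example.py",
        "--data_path", "examples/skating.mp4",  # bundled example video
        "--interval", "10",                     # sample every 10th frame
        "--save_path", "skating.ply",           # arbitrary output location
        "--device", "cuda",                     # or "cpu" if no GPU is available
    ],
    check=True,  # raise if the script exits with a non-zero status
)
```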
    	
        examples/cartoon_horse.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:4edbaec5cc3e747f73f9c01e3a3bef298dc17bb9d2d06d9a9dcd34074f69a73a
         | 
| 3 | 
            +
            size 2314425
         | 
    	
        examples/gradio_examples/house.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:a77b5e44ac7db0e8d1393ef727ef840580bbd3b61fab9b33f6e1df208a7f804f
         | 
| 3 | 
            +
            size 412888
         | 
    	
        examples/gradio_examples/man_walking_long.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:f2a445a27bff4eced24cf44168b7a79e1b486d8df74cf6ca2800097e254b237b
         | 
| 3 | 
            +
            size 952503
         | 
    	
        examples/gradio_examples/parkour.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:1618c992c2f87fe019e90e489457cec0b0e2bcbc8e27768cd36a08fdf7e923db
         | 
| 3 | 
            +
            size 413960
         | 
    	
        examples/gradio_examples/valley.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:01a0b3d4eafc9d012b59f120e408a909ea4f9ab0e2f01f144d24e34af245e129
         | 
| 3 | 
            +
            size 710743
         | 
    	
        examples/parkour_long.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:a27bde28b0a0466e20bbc66a69fcde1c2625c34ee0465fb7cb5d4b4c2fd973fa
         | 
| 3 | 
            +
            size 5316123
         | 
    	
        examples/skating.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:fe64b74f942806e413515fdb79f7825a19ed4e77affb61652dcb554d7b54e05d
         | 
| 3 | 
            +
            size 1438490
         | 
    	
        examples/skiing.mp4
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:65e87d6fdda621f91a5e57e2b90641a6dcf28cbada9b890db8a17fe94a9958e1
         | 
| 3 | 
            +
            size 7960567
         | 
    	
        pi3/models/dinov2/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            __version__ = "0.0.1"
         | 
    	
        pi3/models/dinov2/hub/__init__.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
    	
        pi3/models/dinov2/hub/backbones.py
    ADDED
    
    | @@ -0,0 +1,156 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from enum import Enum
         | 
| 7 | 
            +
            from typing import Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class Weights(Enum):
         | 
| 15 | 
            +
                LVD142M = "LVD142M"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def _make_dinov2_model(
         | 
| 19 | 
            +
                *,
         | 
| 20 | 
            +
                arch_name: str = "vit_large",
         | 
| 21 | 
            +
                img_size: int = 518,
         | 
| 22 | 
            +
                patch_size: int = 14,
         | 
| 23 | 
            +
                init_values: float = 1.0,
         | 
| 24 | 
            +
                ffn_layer: str = "mlp",
         | 
| 25 | 
            +
                block_chunks: int = 0,
         | 
| 26 | 
            +
                num_register_tokens: int = 0,
         | 
| 27 | 
            +
                interpolate_antialias: bool = False,
         | 
| 28 | 
            +
                interpolate_offset: float = 0.1,
         | 
| 29 | 
            +
                pretrained: bool = True,
         | 
| 30 | 
            +
                weights: Union[Weights, str] = Weights.LVD142M,
         | 
| 31 | 
            +
                **kwargs,
         | 
| 32 | 
            +
            ):
         | 
| 33 | 
            +
                from ..models import vision_transformer as vits
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                if isinstance(weights, str):
         | 
| 36 | 
            +
                    try:
         | 
| 37 | 
            +
                        weights = Weights[weights]
         | 
| 38 | 
            +
                    except KeyError:
         | 
| 39 | 
            +
                        raise AssertionError(f"Unsupported weights: {weights}")
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                model_base_name = _make_dinov2_model_name(arch_name, patch_size)
         | 
| 42 | 
            +
                vit_kwargs = dict(
         | 
| 43 | 
            +
                    img_size=img_size,
         | 
| 44 | 
            +
                    patch_size=patch_size,
         | 
| 45 | 
            +
                    init_values=init_values,
         | 
| 46 | 
            +
                    ffn_layer=ffn_layer,
         | 
| 47 | 
            +
                    block_chunks=block_chunks,
         | 
| 48 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 49 | 
            +
                    interpolate_antialias=interpolate_antialias,
         | 
| 50 | 
            +
                    interpolate_offset=interpolate_offset,
         | 
| 51 | 
            +
                )
         | 
| 52 | 
            +
                vit_kwargs.update(**kwargs)
         | 
| 53 | 
            +
                model = vits.__dict__[arch_name](**vit_kwargs)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                if pretrained:
         | 
| 56 | 
            +
                    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
         | 
| 57 | 
            +
                    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
         | 
| 58 | 
            +
                    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
         | 
| 59 | 
            +
                    model.load_state_dict(state_dict, strict=True)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                return model
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
         | 
| 65 | 
            +
                """
         | 
| 66 | 
            +
                DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 67 | 
            +
                """
         | 
| 68 | 
            +
                return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
         | 
| 72 | 
            +
                """
         | 
| 73 | 
            +
                DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 74 | 
            +
                """
         | 
| 75 | 
            +
                return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
         | 
| 79 | 
            +
                """
         | 
| 80 | 
            +
                DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
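The entrypoints above differ only in architecture name, FFN type, and register settings; everything else is delegated to _make_dinov2_model, which is defined earlier in this file. A minimal usage sketch (not part of the committed file; it assumes the repo root is on PYTHONPATH and that _make_dinov2_model skips the weight download when pretrained=False, as in the upstream DINOv2 hub code):

# Usage sketch, not part of the committed file.
from pi3.models.dinov2.hub.backbones import dinov2_vitl14_reg

model = dinov2_vitl14_reg(pretrained=False)  # randomly initialized ViT-L/14 with 4 register tokens
n_params = sum(p.numel() for p in model.parameters())
print(f"ViT-L/14 (reg) parameters: {n_params / 1e6:.1f}M")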
    	
        pi3/models/dinov2/hub/utils.py
    ADDED
    
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
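CenterPadding rounds the trailing spatial dimensions up to the next multiple of `multiple`, splitting the padding as evenly as possible between the two sides. A short sketch (not part of the committed file; the input size is illustrative):

import torch

pad_to_14 = CenterPadding(multiple=14)
x = torch.randn(1, 3, 224, 180)   # width 180 is not a multiple of 14
y = pad_to_14(x)                  # pads the width by 1 pixel on each side
print(y.shape)                    # torch.Size([1, 3, 224, 182])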
    	
        pi3/models/dinov2/layers/__init__.py
    ADDED
    
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
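With these re-exports, downstream code can pull the building blocks from the package in one line (a sketch; the import path assumes this repo layout):

from pi3.models.dinov2.layers import MemEffAttention, Mlp, NestedTensorBlock, PatchEmbed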
    	
        pi3/models/dinov2/layers/attention.py
    ADDED
    
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
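MemEffAttention transparently falls back to the plain Attention path when xFormers is not installed (provided attn_bias is None). A short sketch with illustrative shapes (not part of the committed file):

import torch

attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
tokens = torch.randn(2, 197, 384)   # (batch, tokens, channels)
out = attn(tokens)                  # memory-efficient kernel if xFormers is available, else vanilla attention
print(out.shape)                    # torch.Size([2, 197, 384])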
    	
        pi3/models/dinov2/layers/block.py
    ADDED
    
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
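In eval mode, Block.forward takes the plain pre-norm residual path; the stochastic-depth branches above only apply during training. A short sketch with illustrative sizes (not part of the committed file):

import torch

block = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True, init_values=1e-5)
block.eval()                        # plain residual path, no sample dropping
tokens = torch.randn(2, 197, 384)
print(block(tokens).shape)          # torch.Size([2, 197, 384])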
    	
        pi3/models/dinov2/layers/dino_head.py
    ADDED
    
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
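DINOHead maps backbone features through a bottleneck MLP, L2-normalizes, and projects to prototype logits with a weight-normalized last layer. Sketch with illustrative dimensions (not part of the committed file):

import torch

head = DINOHead(in_dim=384, out_dim=65536)   # 65536 prototypes is illustrative
feats = torch.randn(8, 384)
print(head(feats).shape)                     # torch.Size([8, 65536])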
    	
        pi3/models/dinov2/layers/drop_path.py
    ADDED
    
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
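drop_path zeroes whole samples in the batch and rescales the survivors by 1/keep_prob so the expectation is unchanged; in eval mode it is the identity. Sketch (not part of the committed file):

import torch

dp = DropPath(drop_prob=0.2)
dp.train()
x = torch.ones(8, 4, 16)
y = dp(x)                 # surviving samples are scaled to 1 / 0.8 = 1.25, dropped samples become all zeros
print(y[:, 0, 0])         # mixture of 0.0000 and 1.2500 entries (which samples drop is random)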
    	
        pi3/models/dinov2/layers/layer_scale.py
    ADDED
    
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
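LayerScale is a learnable per-channel gain; with a small init_values each residual branch starts close to the identity mapping. Sketch (not part of the committed file; sizes are illustrative):

import torch

ls = LayerScale(dim=384, init_values=1e-5)
x = torch.randn(2, 197, 384)
y = ls(x)                  # elementwise x * gamma, broadcast over the last dimension
print(y.abs().max())       # on the order of 1e-5 at initialization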
    	
        pi3/models/dinov2/layers/mlp.py
    ADDED
    
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
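The standard ViT feed-forward block: linear, activation, dropout, linear, dropout. Sketch with the usual 4x expansion (not part of the committed file):

import torch

mlp = Mlp(in_features=384, hidden_features=4 * 384)
x = torch.randn(2, 197, 384)
print(mlp(x).shape)        # torch.Size([2, 197, 384])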
    	
        pi3/models/dinov2/layers/patch_embed.py
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from typing import Callable, Optional, Tuple, Union
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from torch import Tensor
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def make_2tuple(x):
         | 
| 17 | 
            +
                if isinstance(x, tuple):
         | 
| 18 | 
            +
                    assert len(x) == 2
         | 
| 19 | 
            +
                    return x
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                assert isinstance(x, int)
         | 
| 22 | 
            +
                return (x, x)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                2D image to patch embedding: (B,C,H,W) -> (B,N,D)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Args:
         | 
| 30 | 
            +
                    img_size: Image size.
         | 
| 31 | 
            +
                    patch_size: Patch token size.
         | 
| 32 | 
            +
                    in_chans: Number of input image channels.
         | 
| 33 | 
            +
                    embed_dim: Number of linear projection output channels.
         | 
| 34 | 
            +
                    norm_layer: Normalization layer.
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __init__(
         | 
| 38 | 
            +
                    self,
         | 
| 39 | 
            +
                    img_size: Union[int, Tuple[int, int]] = 224,
         | 
| 40 | 
            +
                    patch_size: Union[int, Tuple[int, int]] = 16,
         | 
| 41 | 
            +
                    in_chans: int = 3,
         | 
| 42 | 
            +
                    embed_dim: int = 768,
         | 
| 43 | 
            +
                    norm_layer: Optional[Callable] = None,
         | 
| 44 | 
            +
                    flatten_embedding: bool = True,
         | 
| 45 | 
            +
                ) -> None:
         | 
| 46 | 
            +
                    super().__init__()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    image_HW = make_2tuple(img_size)
         | 
| 49 | 
            +
                    patch_HW = make_2tuple(patch_size)
         | 
| 50 | 
            +
                    patch_grid_size = (
         | 
| 51 | 
            +
                        image_HW[0] // patch_HW[0],
         | 
| 52 | 
            +
                        image_HW[1] // patch_HW[1],
         | 
| 53 | 
            +
                    )
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    self.img_size = image_HW
         | 
| 56 | 
            +
                    self.patch_size = patch_HW
         | 
| 57 | 
            +
                    self.patches_resolution = patch_grid_size
         | 
| 58 | 
            +
                    self.num_patches = patch_grid_size[0] * patch_grid_size[1]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    self.in_chans = in_chans
         | 
| 61 | 
            +
                    self.embed_dim = embed_dim
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    self.flatten_embedding = flatten_embedding
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
         | 
| 66 | 
            +
                    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 69 | 
            +
                    _, _, H, W = x.shape
         | 
| 70 | 
            +
                    patch_H, patch_W = self.patch_size
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
         | 
| 73 | 
            +
                    assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    x = self.proj(x)  # B C H W
         | 
| 76 | 
            +
                    H, W = x.size(2), x.size(3)
         | 
| 77 | 
            +
                    x = x.flatten(2).transpose(1, 2)  # B HW C
         | 
| 78 | 
            +
                    x = self.norm(x)
         | 
| 79 | 
            +
                    if not self.flatten_embedding:
         | 
| 80 | 
            +
                        x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
         | 
| 81 | 
            +
                    return x
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                def flops(self) -> float:
         | 
| 84 | 
            +
                    Ho, Wo = self.patches_resolution
         | 
| 85 | 
            +
                    flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         | 
| 86 | 
            +
                    if self.norm is not None:
         | 
| 87 | 
            +
                        flops += Ho * Wo * self.embed_dim
         | 
| 88 | 
            +
                    return flops
         | 
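The `PatchEmbed` module above turns an image into a sequence of patch tokens via a strided convolution. A minimal usage sketch (the import path is an assumption based on the file layout shown in this commit):

```python
# Hedged sketch: using the PatchEmbed module added above.
# The import path is assumed from this repo's layout.
import torch
from pi3.models.dinov2.layers.patch_embed import PatchEmbed

# A 224x224 RGB image with 16x16 patches gives a 14x14 grid = 196 tokens of dim 768.
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)   # (B, C, H, W); H and W must be multiples of 16
tokens = embed(x)                 # (B, 196, 768) since flatten_embedding=True
print(tokens.shape)               # torch.Size([2, 196, 768])
```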
    	
        pi3/models/dinov2/layers/swiglu_ffn.py
    ADDED
    
    | @@ -0,0 +1,72 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            from typing import Callable, Optional
         | 
| 8 | 
            +
            import warnings
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from torch import Tensor, nn
         | 
| 11 | 
            +
            import torch.nn.functional as F
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class SwiGLUFFN(nn.Module):
         | 
| 15 | 
            +
                def __init__(
         | 
| 16 | 
            +
                    self,
         | 
| 17 | 
            +
                    in_features: int,
         | 
| 18 | 
            +
                    hidden_features: Optional[int] = None,
         | 
| 19 | 
            +
                    out_features: Optional[int] = None,
         | 
| 20 | 
            +
                    act_layer: Callable[..., nn.Module] = None,
         | 
| 21 | 
            +
                    drop: float = 0.0,
         | 
| 22 | 
            +
                    bias: bool = True,
         | 
| 23 | 
            +
                ) -> None:
         | 
| 24 | 
            +
                    super().__init__()
         | 
| 25 | 
            +
                    out_features = out_features or in_features
         | 
| 26 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 27 | 
            +
                    self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
         | 
| 28 | 
            +
                    self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 31 | 
            +
                    x12 = self.w12(x)
         | 
| 32 | 
            +
                    x1, x2 = x12.chunk(2, dim=-1)
         | 
| 33 | 
            +
                    hidden = F.silu(x1) * x2
         | 
| 34 | 
            +
                    return self.w3(hidden)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
         | 
| 38 | 
            +
            try:
         | 
| 39 | 
            +
                if XFORMERS_ENABLED:
         | 
| 40 | 
            +
                    from xformers.ops import SwiGLU
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    XFORMERS_AVAILABLE = True
         | 
| 43 | 
            +
                    # warnings.warn("xFormers is available (SwiGLU)")
         | 
| 44 | 
            +
                else:
         | 
| 45 | 
            +
                    # warnings.warn("xFormers is disabled (SwiGLU)")
         | 
| 46 | 
            +
                    raise ImportError
         | 
| 47 | 
            +
            except ImportError:
         | 
| 48 | 
            +
                SwiGLU = SwiGLUFFN
         | 
| 49 | 
            +
                XFORMERS_AVAILABLE = False
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                # warnings.warn("xFormers is not available (SwiGLU)")
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            class SwiGLUFFNFused(SwiGLU):
         | 
| 55 | 
            +
                def __init__(
         | 
| 56 | 
            +
                    self,
         | 
| 57 | 
            +
                    in_features: int,
         | 
| 58 | 
            +
                    hidden_features: Optional[int] = None,
         | 
| 59 | 
            +
                    out_features: Optional[int] = None,
         | 
| 60 | 
            +
                    act_layer: Callable[..., nn.Module] = None,
         | 
| 61 | 
            +
                    drop: float = 0.0,
         | 
| 62 | 
            +
                    bias: bool = True,
         | 
| 63 | 
            +
                ) -> None:
         | 
| 64 | 
            +
                    out_features = out_features or in_features
         | 
| 65 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 66 | 
            +
                    hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
         | 
| 67 | 
            +
                    super().__init__(
         | 
| 68 | 
            +
                        in_features=in_features,
         | 
| 69 | 
            +
                        hidden_features=hidden_features,
         | 
| 70 | 
            +
                        out_features=out_features,
         | 
| 71 | 
            +
                        bias=bias,
         | 
| 72 | 
            +
                    )
         | 
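`SwiGLUFFN` implements the gated SiLU feed-forward (the two halves of `w12` act as gate and value, then `w3` projects back), while `SwiGLUFFNFused` shrinks the hidden width by 2/3 and rounds it up to a multiple of 8 before deferring to the xformers `SwiGLU` op, falling back to the pure-PyTorch class when xformers is unavailable. A small sketch, assuming the import path below:

```python
# Hedged sketch of the SwiGLU FFN variants added above (pure-PyTorch fallback path).
import torch
from pi3.models.dinov2.layers.swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused

x = torch.randn(2, 196, 768)

ffn = SwiGLUFFN(in_features=768, hidden_features=3072)
print(ffn(x).shape)      # torch.Size([2, 196, 768])

# hidden_features=3072 is remapped to (int(3072 * 2 / 3) + 7) // 8 * 8 = 2048
fused = SwiGLUFFNFused(in_features=768, hidden_features=3072)
print(fused(x).shape)    # torch.Size([2, 196, 768])
```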
    	
        pi3/models/dinov2/models/__init__.py
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from . import vision_transformer as vits
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            logger = logging.getLogger("dinov2")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def build_model(args, only_teacher=False, img_size=224):
         | 
| 15 | 
            +
                args.arch = args.arch.removesuffix("_memeff")
         | 
| 16 | 
            +
                if "vit" in args.arch:
         | 
| 17 | 
            +
                    vit_kwargs = dict(
         | 
| 18 | 
            +
                        img_size=img_size,
         | 
| 19 | 
            +
                        patch_size=args.patch_size,
         | 
| 20 | 
            +
                        init_values=args.layerscale,
         | 
| 21 | 
            +
                        ffn_layer=args.ffn_layer,
         | 
| 22 | 
            +
                        block_chunks=args.block_chunks,
         | 
| 23 | 
            +
                        qkv_bias=args.qkv_bias,
         | 
| 24 | 
            +
                        proj_bias=args.proj_bias,
         | 
| 25 | 
            +
                        ffn_bias=args.ffn_bias,
         | 
| 26 | 
            +
                        num_register_tokens=args.num_register_tokens,
         | 
| 27 | 
            +
                        interpolate_offset=args.interpolate_offset,
         | 
| 28 | 
            +
                        interpolate_antialias=args.interpolate_antialias,
         | 
| 29 | 
            +
                    )
         | 
| 30 | 
            +
                    teacher = vits.__dict__[args.arch](**vit_kwargs)
         | 
| 31 | 
            +
                    if only_teacher:
         | 
| 32 | 
            +
                        return teacher, teacher.embed_dim
         | 
| 33 | 
            +
                    student = vits.__dict__[args.arch](
         | 
| 34 | 
            +
                        **vit_kwargs,
         | 
| 35 | 
            +
                        drop_path_rate=args.drop_path_rate,
         | 
| 36 | 
            +
                        drop_path_uniform=args.drop_path_uniform,
         | 
| 37 | 
            +
                    )
         | 
| 38 | 
            +
                    embed_dim = student.embed_dim
         | 
| 39 | 
            +
                return student, teacher, embed_dim
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def build_model_from_cfg(cfg, only_teacher=False):
         | 
| 43 | 
            +
                return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
         | 
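`build_model` instantiates a teacher (and optionally a student) ViT from a config namespace, and `build_model_from_cfg` forwards the `student` sub-config plus the global crop size. A hypothetical sketch with a `SimpleNamespace` standing in for that config; the field names mirror the attributes read above, and actually running it also depends on the attention backends imported by `vision_transformer.py`:

```python
# Hypothetical sketch: calling build_model with a hand-rolled namespace.
from types import SimpleNamespace
from pi3.models.dinov2.models import build_model

args = SimpleNamespace(
    arch="vit_large",
    patch_size=14,
    layerscale=1e-5,          # forwarded as init_values
    ffn_layer="mlp",
    block_chunks=0,
    qkv_bias=True,
    proj_bias=True,
    ffn_bias=True,
    num_register_tokens=4,
    interpolate_offset=0.1,
    interpolate_antialias=False,
    drop_path_rate=0.0,
    drop_path_uniform=False,
)
teacher, embed_dim = build_model(args, only_teacher=True, img_size=224)
print(embed_dim)              # 1024 for vit_large
```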
    	
        pi3/models/dinov2/models/vision_transformer.py
    ADDED
    
    | @@ -0,0 +1,404 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from functools import partial
         | 
| 11 | 
            +
            import math
         | 
| 12 | 
            +
            import logging
         | 
| 13 | 
            +
            from typing import Sequence, Tuple, Union, Callable
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import torch
         | 
| 16 | 
            +
            import torch.nn as nn
         | 
| 17 | 
            +
            from torch.utils.checkpoint import checkpoint
         | 
| 18 | 
            +
            from torch.nn.init import trunc_normal_
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
         | 
| 21 | 
            +
            from ...layers.attention import FlashAttention
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            # logger = logging.getLogger("dinov2")
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
         | 
| 28 | 
            +
                if not depth_first and include_root:
         | 
| 29 | 
            +
                    fn(module=module, name=name)
         | 
| 30 | 
            +
                for child_name, child_module in module.named_children():
         | 
| 31 | 
            +
                    child_name = ".".join((name, child_name)) if name else child_name
         | 
| 32 | 
            +
                    named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
         | 
| 33 | 
            +
                if depth_first and include_root:
         | 
| 34 | 
            +
                    fn(module=module, name=name)
         | 
| 35 | 
            +
                return module
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            class BlockChunk(nn.ModuleList):
         | 
| 39 | 
            +
                def forward(self, x):
         | 
| 40 | 
            +
                    for b in self:
         | 
| 41 | 
            +
                        x = b(x)
         | 
| 42 | 
            +
                    return x
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            class DinoVisionTransformer(nn.Module):
         | 
| 46 | 
            +
                def __init__(
         | 
| 47 | 
            +
                    self,
         | 
| 48 | 
            +
                    img_size=224,
         | 
| 49 | 
            +
                    patch_size=16,
         | 
| 50 | 
            +
                    in_chans=3,
         | 
| 51 | 
            +
                    embed_dim=768,
         | 
| 52 | 
            +
                    depth=12,
         | 
| 53 | 
            +
                    num_heads=12,
         | 
| 54 | 
            +
                    mlp_ratio=4.0,
         | 
| 55 | 
            +
                    qkv_bias=True,
         | 
| 56 | 
            +
                    ffn_bias=True,
         | 
| 57 | 
            +
                    proj_bias=True,
         | 
| 58 | 
            +
                    drop_path_rate=0.0,
         | 
| 59 | 
            +
                    drop_path_uniform=False,
         | 
| 60 | 
            +
                    init_values=None,  # for layerscale: None or 0 => no layerscale
         | 
| 61 | 
            +
                    embed_layer=PatchEmbed,
         | 
| 62 | 
            +
                    act_layer=nn.GELU,
         | 
| 63 | 
            +
                    block_fn=Block,
         | 
| 64 | 
            +
                    ffn_layer="mlp",
         | 
| 65 | 
            +
                    block_chunks=1,
         | 
| 66 | 
            +
                    num_register_tokens=0,
         | 
| 67 | 
            +
                    interpolate_antialias=False,
         | 
| 68 | 
            +
                    interpolate_offset=0.1,
         | 
| 69 | 
            +
                ):
         | 
| 70 | 
            +
                    """
         | 
| 71 | 
            +
                    Args:
         | 
| 72 | 
            +
                        img_size (int, tuple): input image size
         | 
| 73 | 
            +
                        patch_size (int, tuple): patch size
         | 
| 74 | 
            +
                        in_chans (int): number of input channels
         | 
| 75 | 
            +
                        embed_dim (int): embedding dimension
         | 
| 76 | 
            +
                        depth (int): depth of transformer
         | 
| 77 | 
            +
                        num_heads (int): number of attention heads
         | 
| 78 | 
            +
                        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
         | 
| 79 | 
            +
                        qkv_bias (bool): enable bias for qkv if True
         | 
| 80 | 
            +
                        proj_bias (bool): enable bias for proj in attn if True
         | 
| 81 | 
            +
                        ffn_bias (bool): enable bias for ffn if True
         | 
| 82 | 
            +
                        drop_path_rate (float): stochastic depth rate
         | 
| 83 | 
            +
                        drop_path_uniform (bool): apply uniform drop rate across blocks
         | 
| 84 | 
            +
                        weight_init (str): weight init scheme
         | 
| 85 | 
            +
                        init_values (float): layer-scale init values
         | 
| 86 | 
            +
                        embed_layer (nn.Module): patch embedding layer
         | 
| 87 | 
            +
                        act_layer (nn.Module): MLP activation layer
         | 
| 88 | 
            +
                        block_fn (nn.Module): transformer block class
         | 
| 89 | 
            +
                        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
         | 
| 90 | 
            +
                        block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
         | 
| 91 | 
            +
                        num_register_tokens: (int) number of extra cls tokens (so-called "registers")
         | 
| 92 | 
            +
                        interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
         | 
| 93 | 
            +
                        interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
         | 
| 94 | 
            +
                    """
         | 
| 95 | 
            +
                    super().__init__()
         | 
| 96 | 
            +
                    norm_layer = partial(nn.LayerNorm, eps=1e-6)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         | 
| 99 | 
            +
                    self.num_tokens = 1
         | 
| 100 | 
            +
                    self.n_blocks = depth
         | 
| 101 | 
            +
                    self.num_heads = num_heads
         | 
| 102 | 
            +
                    self.patch_size = patch_size
         | 
| 103 | 
            +
                    self.num_register_tokens = num_register_tokens
         | 
| 104 | 
            +
                    self.interpolate_antialias = interpolate_antialias
         | 
| 105 | 
            +
                    self.interpolate_offset = interpolate_offset
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         | 
| 108 | 
            +
                    num_patches = self.patch_embed.num_patches
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         | 
| 111 | 
            +
                    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         | 
| 112 | 
            +
                    assert num_register_tokens >= 0
         | 
| 113 | 
            +
                    self.register_tokens = (
         | 
| 114 | 
            +
                        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
         | 
| 115 | 
            +
                    )
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    if drop_path_uniform is True:
         | 
| 118 | 
            +
                        dpr = [drop_path_rate] * depth
         | 
| 119 | 
            +
                    else:
         | 
| 120 | 
            +
                        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    if ffn_layer == "mlp":
         | 
| 123 | 
            +
                        # logger.info("using MLP layer as FFN")
         | 
| 124 | 
            +
                        ffn_layer = Mlp
         | 
| 125 | 
            +
                    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
         | 
| 126 | 
            +
                        # logger.info("using SwiGLU layer as FFN")
         | 
| 127 | 
            +
                        ffn_layer = SwiGLUFFNFused
         | 
| 128 | 
            +
                    elif ffn_layer == "identity":
         | 
| 129 | 
            +
                        # logger.info("using Identity layer as FFN")
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                        def f(*args, **kwargs):
         | 
| 132 | 
            +
                            return nn.Identity()
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                        ffn_layer = f
         | 
| 135 | 
            +
                    else:
         | 
| 136 | 
            +
                        raise NotImplementedError
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    blocks_list = [
         | 
| 139 | 
            +
                        block_fn(
         | 
| 140 | 
            +
                            dim=embed_dim,
         | 
| 141 | 
            +
                            num_heads=num_heads,
         | 
| 142 | 
            +
                            mlp_ratio=mlp_ratio,
         | 
| 143 | 
            +
                            qkv_bias=qkv_bias,
         | 
| 144 | 
            +
                            proj_bias=proj_bias,
         | 
| 145 | 
            +
                            ffn_bias=ffn_bias,
         | 
| 146 | 
            +
                            drop_path=dpr[i],
         | 
| 147 | 
            +
                            norm_layer=norm_layer,
         | 
| 148 | 
            +
                            act_layer=act_layer,
         | 
| 149 | 
            +
                            ffn_layer=ffn_layer,
         | 
| 150 | 
            +
                            init_values=init_values,
         | 
| 151 | 
            +
                            attn_class=FlashAttention
         | 
| 152 | 
            +
                        )
         | 
| 153 | 
            +
                        for i in range(depth)
         | 
| 154 | 
            +
                    ]
         | 
| 155 | 
            +
                    if block_chunks > 0:
         | 
| 156 | 
            +
                        self.chunked_blocks = True
         | 
| 157 | 
            +
                        chunked_blocks = []
         | 
| 158 | 
            +
                        chunksize = depth // block_chunks
         | 
| 159 | 
            +
                        for i in range(0, depth, chunksize):
         | 
| 160 | 
            +
                            # this is to keep the block index consistent if we chunk the block list
         | 
| 161 | 
            +
                            chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
         | 
| 162 | 
            +
                        self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
         | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        self.chunked_blocks = False
         | 
| 165 | 
            +
                        self.blocks = nn.ModuleList(blocks_list)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    self.norm = norm_layer(embed_dim)
         | 
| 168 | 
            +
                    self.head = nn.Identity()
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    self.init_weights()
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def init_weights(self):
         | 
| 175 | 
            +
                    trunc_normal_(self.pos_embed, std=0.02)
         | 
| 176 | 
            +
                    nn.init.normal_(self.cls_token, std=1e-6)
         | 
| 177 | 
            +
                    if self.register_tokens is not None:
         | 
| 178 | 
            +
                        nn.init.normal_(self.register_tokens, std=1e-6)
         | 
| 179 | 
            +
                    named_apply(init_weights_vit_timm, self)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def interpolate_pos_encoding(self, x, w, h):
         | 
| 182 | 
            +
                    previous_dtype = x.dtype
         | 
| 183 | 
            +
                    npatch = x.shape[1] - 1
         | 
| 184 | 
            +
                    N = self.pos_embed.shape[1] - 1
         | 
| 185 | 
            +
                    if npatch == N and w == h:
         | 
| 186 | 
            +
                        return self.pos_embed
         | 
| 187 | 
            +
                    pos_embed = self.pos_embed.float()
         | 
| 188 | 
            +
                    class_pos_embed = pos_embed[:, 0]
         | 
| 189 | 
            +
                    patch_pos_embed = pos_embed[:, 1:]
         | 
| 190 | 
            +
                    dim = x.shape[-1]
         | 
| 191 | 
            +
                    w0 = w // self.patch_size
         | 
| 192 | 
            +
                    h0 = h // self.patch_size
         | 
| 193 | 
            +
                    M = int(math.sqrt(N))  # Recover the number of patches in each dimension
         | 
| 194 | 
            +
                    assert N == M * M
         | 
| 195 | 
            +
                    kwargs = {}
         | 
| 196 | 
            +
                    if self.interpolate_offset:
         | 
| 197 | 
            +
                        # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
         | 
| 198 | 
            +
                        # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
         | 
| 199 | 
            +
                        sx = float(w0 + self.interpolate_offset) / M
         | 
| 200 | 
            +
                        sy = float(h0 + self.interpolate_offset) / M
         | 
| 201 | 
            +
                        kwargs["scale_factor"] = (sx, sy)
         | 
| 202 | 
            +
                    else:
         | 
| 203 | 
            +
                        # Simply specify an output size instead of a scale factor
         | 
| 204 | 
            +
                        kwargs["size"] = (w0, h0)
         | 
| 205 | 
            +
                    patch_pos_embed = nn.functional.interpolate(
         | 
| 206 | 
            +
                        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
         | 
| 207 | 
            +
                        mode="bicubic",
         | 
| 208 | 
            +
                        antialias=self.interpolate_antialias,
         | 
| 209 | 
            +
                        **kwargs,
         | 
| 210 | 
            +
                    )
         | 
| 211 | 
            +
                    assert (w0, h0) == patch_pos_embed.shape[-2:]
         | 
| 212 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         | 
| 213 | 
            +
                    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def prepare_tokens_with_masks(self, x, masks=None):
         | 
| 216 | 
            +
                    B, nc, w, h = x.shape
         | 
| 217 | 
            +
                    x = self.patch_embed(x)
         | 
| 218 | 
            +
                    if masks is not None:
         | 
| 219 | 
            +
                        x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         | 
| 222 | 
            +
                    x = x + self.interpolate_pos_encoding(x, w, h)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    if self.register_tokens is not None:
         | 
| 225 | 
            +
                        x = torch.cat(
         | 
| 226 | 
            +
                            (
         | 
| 227 | 
            +
                                x[:, :1],
         | 
| 228 | 
            +
                                self.register_tokens.expand(x.shape[0], -1, -1),
         | 
| 229 | 
            +
                                x[:, 1:],
         | 
| 230 | 
            +
                            ),
         | 
| 231 | 
            +
                            dim=1,
         | 
| 232 | 
            +
                        )
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    return x
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def forward_features_list(self, x_list, masks_list):
         | 
| 237 | 
            +
                    x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
         | 
| 238 | 
            +
                    for blk in self.blocks:
         | 
| 239 | 
            +
                        if self.training:
         | 
| 240 | 
            +
                            x = checkpoint(blk, x, use_reentrant=False)
         | 
| 241 | 
            +
                        else:
         | 
| 242 | 
            +
                            x = blk(x)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    all_x = x
         | 
| 245 | 
            +
                    output = []
         | 
| 246 | 
            +
                    for x, masks in zip(all_x, masks_list):
         | 
| 247 | 
            +
                        x_norm = self.norm(x)
         | 
| 248 | 
            +
                        output.append(
         | 
| 249 | 
            +
                            {
         | 
| 250 | 
            +
                                "x_norm_clstoken": x_norm[:, 0],
         | 
| 251 | 
            +
                                "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
         | 
| 252 | 
            +
                                "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
         | 
| 253 | 
            +
                                "x_prenorm": x,
         | 
| 254 | 
            +
                                "masks": masks,
         | 
| 255 | 
            +
                            }
         | 
| 256 | 
            +
                        )
         | 
| 257 | 
            +
                    return output
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                def forward_features(self, x, masks=None):
         | 
| 260 | 
            +
                    if isinstance(x, list):
         | 
| 261 | 
            +
                        return self.forward_features_list(x, masks)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    x = self.prepare_tokens_with_masks(x, masks)
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    for blk in self.blocks:
         | 
| 266 | 
            +
                        if self.training:
         | 
| 267 | 
            +
                            x = checkpoint(blk, x, use_reentrant=False)
         | 
| 268 | 
            +
                        else:
         | 
| 269 | 
            +
                            x = blk(x)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    x_norm = self.norm(x)
         | 
| 272 | 
            +
                    return {
         | 
| 273 | 
            +
                        "x_norm_clstoken": x_norm[:, 0],
         | 
| 274 | 
            +
                        "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
         | 
| 275 | 
            +
                        "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
         | 
| 276 | 
            +
                        "x_prenorm": x,
         | 
| 277 | 
            +
                        "masks": masks,
         | 
| 278 | 
            +
                    }
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def _get_intermediate_layers_not_chunked(self, x, n=1):
         | 
| 281 | 
            +
                    x = self.prepare_tokens_with_masks(x)
         | 
| 282 | 
            +
                    # If n is an int, take the n last blocks. If it's a list, take them
         | 
| 283 | 
            +
                    output, total_block_len = [], len(self.blocks)
         | 
| 284 | 
            +
                    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
         | 
| 285 | 
            +
                    for i, blk in enumerate(self.blocks):
         | 
| 286 | 
            +
                        x = blk(x)
         | 
| 287 | 
            +
                        if i in blocks_to_take:
         | 
| 288 | 
            +
                            output.append(x)
         | 
| 289 | 
            +
                    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         | 
| 290 | 
            +
                    return output
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                def _get_intermediate_layers_chunked(self, x, n=1):
         | 
| 293 | 
            +
                    x = self.prepare_tokens_with_masks(x)
         | 
| 294 | 
            +
                    output, i, total_block_len = [], 0, len(self.blocks[-1])
         | 
| 295 | 
            +
                    # If n is an int, take the n last blocks. If it's a list, take them
         | 
| 296 | 
            +
                    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
         | 
| 297 | 
            +
                    for block_chunk in self.blocks:
         | 
| 298 | 
            +
                        for blk in block_chunk[i:]:  # Passing the nn.Identity()
         | 
| 299 | 
            +
                            x = blk(x)
         | 
| 300 | 
            +
                            if i in blocks_to_take:
         | 
| 301 | 
            +
                                output.append(x)
         | 
| 302 | 
            +
                            i += 1
         | 
| 303 | 
            +
                    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         | 
| 304 | 
            +
                    return output
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                def get_intermediate_layers(
         | 
| 307 | 
            +
                    self,
         | 
| 308 | 
            +
                    x: torch.Tensor,
         | 
| 309 | 
            +
                    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
         | 
| 310 | 
            +
                    reshape: bool = False,
         | 
| 311 | 
            +
                    return_class_token: bool = False,
         | 
| 312 | 
            +
                    norm=True,
         | 
| 313 | 
            +
                ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
         | 
| 314 | 
            +
                    if self.chunked_blocks:
         | 
| 315 | 
            +
                        outputs = self._get_intermediate_layers_chunked(x, n)
         | 
| 316 | 
            +
                    else:
         | 
| 317 | 
            +
                        outputs = self._get_intermediate_layers_not_chunked(x, n)
         | 
| 318 | 
            +
                    if norm:
         | 
| 319 | 
            +
                        outputs = [self.norm(out) for out in outputs]
         | 
| 320 | 
            +
                    class_tokens = [out[:, 0] for out in outputs]
         | 
| 321 | 
            +
                    outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         | 
| 322 | 
            +
                    if reshape:
         | 
| 323 | 
            +
                        B, _, w, h = x.shape
         | 
| 324 | 
            +
                        outputs = [
         | 
| 325 | 
            +
                            out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
         | 
| 326 | 
            +
                            for out in outputs
         | 
| 327 | 
            +
                        ]
         | 
| 328 | 
            +
                    if return_class_token:
         | 
| 329 | 
            +
                        return tuple(zip(outputs, class_tokens))
         | 
| 330 | 
            +
                    return tuple(outputs)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                def forward(self, *args, is_training=False, **kwargs):
         | 
| 333 | 
            +
                    ret = self.forward_features(*args, **kwargs)
         | 
| 334 | 
            +
                    if is_training:
         | 
| 335 | 
            +
                        return ret
         | 
| 336 | 
            +
                    else:
         | 
| 337 | 
            +
                        return self.head(ret["x_norm_clstoken"])
         | 
| 338 | 
            +
             | 
| 339 | 
            +
             | 
| 340 | 
            +
            def init_weights_vit_timm(module: nn.Module, name: str = ""):
         | 
| 341 | 
            +
                """ViT weight initialization, original timm impl (for reproducibility)"""
         | 
| 342 | 
            +
                if isinstance(module, nn.Linear):
         | 
| 343 | 
            +
                    trunc_normal_(module.weight, std=0.02)
         | 
| 344 | 
            +
                    if module.bias is not None:
         | 
| 345 | 
            +
                        nn.init.zeros_(module.bias)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
             | 
| 348 | 
            +
            def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 349 | 
            +
                model = DinoVisionTransformer(
         | 
| 350 | 
            +
                    patch_size=patch_size,
         | 
| 351 | 
            +
                    embed_dim=384,
         | 
| 352 | 
            +
                    depth=12,
         | 
| 353 | 
            +
                    num_heads=6,
         | 
| 354 | 
            +
                    mlp_ratio=4,
         | 
| 355 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 356 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 357 | 
            +
                    **kwargs,
         | 
| 358 | 
            +
                )
         | 
| 359 | 
            +
                return model
         | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
            def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 363 | 
            +
                model = DinoVisionTransformer(
         | 
| 364 | 
            +
                    patch_size=patch_size,
         | 
| 365 | 
            +
                    embed_dim=768,
         | 
| 366 | 
            +
                    depth=12,
         | 
| 367 | 
            +
                    num_heads=12,
         | 
| 368 | 
            +
                    mlp_ratio=4,
         | 
| 369 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 370 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 371 | 
            +
                    **kwargs,
         | 
| 372 | 
            +
                )
         | 
| 373 | 
            +
                return model
         | 
| 374 | 
            +
             | 
| 375 | 
            +
             | 
| 376 | 
            +
            def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 377 | 
            +
                model = DinoVisionTransformer(
         | 
| 378 | 
            +
                    patch_size=patch_size,
         | 
| 379 | 
            +
                    embed_dim=1024,
         | 
| 380 | 
            +
                    depth=24,
         | 
| 381 | 
            +
                    num_heads=16,
         | 
| 382 | 
            +
                    mlp_ratio=4,
         | 
| 383 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 384 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 385 | 
            +
                    **kwargs,
         | 
| 386 | 
            +
                )
         | 
| 387 | 
            +
                return model
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
            def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 391 | 
            +
                """
         | 
| 392 | 
            +
                Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
         | 
| 393 | 
            +
                """
         | 
| 394 | 
            +
                model = DinoVisionTransformer(
         | 
| 395 | 
            +
                    patch_size=patch_size,
         | 
| 396 | 
            +
                    embed_dim=1536,
         | 
| 397 | 
            +
                    depth=40,
         | 
| 398 | 
            +
                    num_heads=24,
         | 
| 399 | 
            +
                    mlp_ratio=4,
         | 
| 400 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 401 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 402 | 
            +
                    **kwargs,
         | 
| 403 | 
            +
                )
         | 
| 404 | 
            +
                return model
         | 
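`DinoVisionTransformer` exposes `forward_features` (a token dictionary) and `get_intermediate_layers` (per-block feature maps), which is what downstream dense heads typically consume. A hedged sketch using the `vit_small` factory defined above; note that the blocks are built with `attn_class=FlashAttention`, so an environment where that attention implementation is usable (typically a GPU with half-precision inputs) is assumed, and the shapes below follow from the slicing logic in the file rather than from a verified run:

```python
# Hedged sketch: pulling patch features out of the ViT defined above.
import torch
from pi3.models.dinov2.models.vision_transformer import vit_small

model = vit_small(patch_size=16, num_register_tokens=4, block_chunks=0).eval()
x = torch.randn(1, 3, 224, 224)   # sides must be multiples of patch_size

with torch.no_grad():
    out = model.forward_features(x)
    cls_tok = out["x_norm_clstoken"]        # (1, 384)
    patches = out["x_norm_patchtokens"]     # (1, 196, 384); cls + register tokens sliced off

    # Last 4 blocks as (B, C, H/p, W/p) maps, e.g. for a dense prediction head.
    feats = model.get_intermediate_layers(x, n=4, reshape=True)
    print(cls_tok.shape, patches.shape, feats[-1].shape)   # ..., torch.Size([1, 384, 14, 14])
```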
    	
        pi3/models/dinov2/utils/__init__.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
    	
        pi3/models/dinov2/utils/cluster.py
    ADDED
    
    | @@ -0,0 +1,95 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from enum import Enum
         | 
| 7 | 
            +
            import os
         | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
            from typing import Any, Dict, Optional
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class ClusterType(Enum):
         | 
| 13 | 
            +
                AWS = "aws"
         | 
| 14 | 
            +
                FAIR = "fair"
         | 
| 15 | 
            +
                RSC = "rsc"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def _guess_cluster_type() -> ClusterType:
         | 
| 19 | 
            +
                uname = os.uname()
         | 
| 20 | 
            +
                if uname.sysname == "Linux":
         | 
| 21 | 
            +
                    if uname.release.endswith("-aws"):
         | 
| 22 | 
            +
                        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
         | 
| 23 | 
            +
                        return ClusterType.AWS
         | 
| 24 | 
            +
                    elif uname.nodename.startswith("rsc"):
         | 
| 25 | 
            +
                        # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
         | 
| 26 | 
            +
                        return ClusterType.RSC
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                return ClusterType.FAIR
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
         | 
| 32 | 
            +
                if cluster_type is None:
         | 
| 33 | 
            +
                    return _guess_cluster_type()
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                return cluster_type
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
         | 
| 39 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 40 | 
            +
                if cluster_type is None:
         | 
| 41 | 
            +
                    return None
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                CHECKPOINT_DIRNAMES = {
         | 
| 44 | 
            +
                    ClusterType.AWS: "checkpoints",
         | 
| 45 | 
            +
                    ClusterType.FAIR: "checkpoint",
         | 
| 46 | 
            +
                    ClusterType.RSC: "checkpoint/dino",
         | 
| 47 | 
            +
                }
         | 
| 48 | 
            +
                return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
         | 
| 52 | 
            +
                checkpoint_path = get_checkpoint_path(cluster_type)
         | 
| 53 | 
            +
                if checkpoint_path is None:
         | 
| 54 | 
            +
                    return None
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                username = os.environ.get("USER")
         | 
| 57 | 
            +
                assert username is not None
         | 
| 58 | 
            +
                return checkpoint_path / username
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
         | 
| 62 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 63 | 
            +
                if cluster_type is None:
         | 
| 64 | 
            +
                    return None
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                SLURM_PARTITIONS = {
         | 
| 67 | 
            +
                    ClusterType.AWS: "learnlab",
         | 
| 68 | 
            +
                    ClusterType.FAIR: "learnlab",
         | 
| 69 | 
            +
                    ClusterType.RSC: "learn",
         | 
| 70 | 
            +
                }
         | 
| 71 | 
            +
                return SLURM_PARTITIONS[cluster_type]
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def get_slurm_executor_parameters(
         | 
| 75 | 
            +
                nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
         | 
| 76 | 
            +
            ) -> Dict[str, Any]:
         | 
| 77 | 
            +
                # create default parameters
         | 
| 78 | 
            +
                params = {
         | 
| 79 | 
            +
                    "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
         | 
| 80 | 
            +
                    "gpus_per_node": num_gpus_per_node,
         | 
| 81 | 
            +
                    "tasks_per_node": num_gpus_per_node,  # one task per GPU
         | 
| 82 | 
            +
                    "cpus_per_task": 10,
         | 
| 83 | 
            +
                    "nodes": nodes,
         | 
| 84 | 
            +
                    "slurm_partition": get_slurm_partition(cluster_type),
         | 
| 85 | 
            +
                }
         | 
| 86 | 
            +
                # apply cluster-specific adjustments
         | 
| 87 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 88 | 
            +
                if cluster_type == ClusterType.AWS:
         | 
| 89 | 
            +
                    params["cpus_per_task"] = 12
         | 
| 90 | 
            +
                    del params["mem_gb"]
         | 
| 91 | 
            +
                elif cluster_type == ClusterType.RSC:
         | 
| 92 | 
            +
                    params["cpus_per_task"] = 12
         | 
| 93 | 
            +
                # set additional parameters / apply overrides
         | 
| 94 | 
            +
                params.update(kwargs)
         | 
| 95 | 
            +
                return params
         | 
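These helpers guess the cluster flavor from `os.uname()` and map it to checkpoint paths, SLURM partitions, and a dict of executor parameters (`gpus_per_node`, `tasks_per_node`, `cpus_per_task`, `slurm_partition`, ...), with any extra kwargs applied as overrides. A small sketch that pins the cluster type explicitly so nothing is auto-detected (the import path is assumed from the repo layout):

```python
# Hedged sketch: building SLURM executor parameters with the helpers above.
from pi3.models.dinov2.utils.cluster import ClusterType, get_slurm_executor_parameters

params = get_slurm_executor_parameters(
    nodes=2,
    num_gpus_per_node=8,
    cluster_type=ClusterType.AWS,   # skip auto-detection for the example
    timeout_min=60,                 # extra kwargs pass straight through as overrides
)
print(params["slurm_partition"], params["cpus_per_task"])   # learnlab 12
```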
    	
        pi3/models/dinov2/utils/config.py
    ADDED
    
    | @@ -0,0 +1,72 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from omegaconf import OmegaConf
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import dinov2.distributed as distributed
         | 
| 13 | 
            +
            from dinov2.logging import setup_logging
         | 
| 14 | 
            +
            from dinov2.utils import utils
         | 
| 15 | 
            +
            from dinov2.configs import dinov2_default_config
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            logger = logging.getLogger("dinov2")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def apply_scaling_rules_to_cfg(cfg):  # to fix
         | 
| 22 | 
            +
                if cfg.optim.scaling_rule == "sqrt_wrt_1024":
         | 
| 23 | 
            +
                    base_lr = cfg.optim.base_lr
         | 
| 24 | 
            +
                    cfg.optim.lr = base_lr
         | 
| 25 | 
            +
                    cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
         | 
| 26 | 
            +
                    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
         | 
| 27 | 
            +
                else:
         | 
| 28 | 
            +
                    raise NotImplementedError
         | 
| 29 | 
            +
                return cfg
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def write_config(cfg, output_dir, name="config.yaml"):
         | 
| 33 | 
            +
                logger.info(OmegaConf.to_yaml(cfg))
         | 
| 34 | 
            +
                saved_cfg_path = os.path.join(output_dir, name)
         | 
| 35 | 
            +
                with open(saved_cfg_path, "w") as f:
         | 
| 36 | 
            +
                    OmegaConf.save(config=cfg, f=f)
         | 
| 37 | 
            +
                return saved_cfg_path
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def get_cfg_from_args(args):
         | 
| 41 | 
            +
                args.output_dir = os.path.abspath(args.output_dir)
         | 
| 42 | 
            +
                args.opts += [f"train.output_dir={args.output_dir}"]
         | 
| 43 | 
            +
                default_cfg = OmegaConf.create(dinov2_default_config)
         | 
| 44 | 
            +
                cfg = OmegaConf.load(args.config_file)
         | 
| 45 | 
            +
                cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
         | 
| 46 | 
            +
                return cfg
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def default_setup(args):
         | 
| 50 | 
            +
                distributed.enable(overwrite=True)
         | 
| 51 | 
            +
                seed = getattr(args, "seed", 0)
         | 
| 52 | 
            +
                rank = distributed.get_global_rank()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                global logger
         | 
| 55 | 
            +
                setup_logging(output=args.output_dir, level=logging.INFO)
         | 
| 56 | 
            +
                logger = logging.getLogger("dinov2")
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                utils.fix_random_seeds(seed + rank)
         | 
| 59 | 
            +
                logger.info("git:\n  {}\n".format(utils.get_sha()))
         | 
| 60 | 
            +
                logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def setup(args):
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                Create configs and perform basic setups.
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                cfg = get_cfg_from_args(args)
         | 
| 68 | 
            +
                os.makedirs(args.output_dir, exist_ok=True)
         | 
| 69 | 
            +
                default_setup(args)
         | 
| 70 | 
            +
                apply_scaling_rules_to_cfg(cfg)
         | 
| 71 | 
            +
                write_config(cfg, args.output_dir)
         | 
| 72 | 
            +
                return cfg
         | 
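
A hypothetical driver sketch for get_cfg_from_args()/setup() (not part of the committed file; the config path is illustrative). It builds the argparse namespace with the attributes these functions read: config_file, output_dir, opts, and seed.

    import argparse

    parser = argparse.ArgumentParser("dinov2 config")
    parser.add_argument("--config-file", default="configs/train/vitl16.yaml")  # illustrative path
    parser.add_argument("--output-dir", dest="output_dir", default="./output")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("opts", nargs=argparse.REMAINDER, default=[])  # e.g. optim.base_lr=0.004
    args = parser.parse_args([])

    # cfg = setup(args)  # merges dinov2_default_config, the YAML file, and CLI overrides,
    #                    # initializes distributed/logging, then writes config.yaml to output_dir

The setup() call is left commented because it requires a distributed environment and an actual config file on disk.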
    	
        pi3/models/dinov2/utils/dtype.py
    ADDED
    
    | @@ -0,0 +1,37 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            from typing import Dict, Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            TypeSpec = Union[str, np.dtype, torch.dtype]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
         | 
| 17 | 
            +
                np.dtype("bool"): torch.bool,
         | 
| 18 | 
            +
                np.dtype("uint8"): torch.uint8,
         | 
| 19 | 
            +
                np.dtype("int8"): torch.int8,
         | 
| 20 | 
            +
                np.dtype("int16"): torch.int16,
         | 
| 21 | 
            +
                np.dtype("int32"): torch.int32,
         | 
| 22 | 
            +
                np.dtype("int64"): torch.int64,
         | 
| 23 | 
            +
                np.dtype("float16"): torch.float16,
         | 
| 24 | 
            +
                np.dtype("float32"): torch.float32,
         | 
| 25 | 
            +
                np.dtype("float64"): torch.float64,
         | 
| 26 | 
            +
                np.dtype("complex64"): torch.complex64,
         | 
| 27 | 
            +
                np.dtype("complex128"): torch.complex128,
         | 
| 28 | 
            +
            }
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
         | 
| 32 | 
            +
                if isinstance(dtype, torch.dtype):
         | 
| 33 | 
            +
                    return dtype
         | 
| 34 | 
            +
                if isinstance(dtype, str):
         | 
| 35 | 
            +
                    dtype = np.dtype(dtype)
         | 
| 36 | 
            +
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
         | 
| 37 | 
            +
                return _NUMPY_TO_TORCH_DTYPE[dtype]
         | 
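
Usage sketch for as_torch_dtype (illustrative, not part of the committed file): the helper accepts strings, numpy dtypes, or torch dtypes and maps them onto torch dtypes.

    import numpy as np
    import torch

    assert as_torch_dtype("float32") is torch.float32        # string -> numpy dtype -> torch dtype
    assert as_torch_dtype(np.dtype("int64")) is torch.int64   # numpy dtype looked up in the table
    assert as_torch_dtype(torch.bfloat16) is torch.bfloat16   # torch dtypes pass through unchanged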
    	
        pi3/models/dinov2/utils/param_groups.py
    ADDED
    
    | @@ -0,0 +1,103 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from collections import defaultdict
         | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            logger = logging.getLogger("dinov2")
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
                Calculate lr decay rate for different ViT blocks.
         | 
| 16 | 
            +
                Args:
         | 
| 17 | 
            +
                    name (string): parameter name.
         | 
| 18 | 
            +
                    lr_decay_rate (float): base lr decay rate.
         | 
| 19 | 
            +
                    num_layers (int): number of ViT blocks.
         | 
| 20 | 
            +
                Returns:
         | 
| 21 | 
            +
                    lr decay rate for the given parameter.
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
                layer_id = num_layers + 1
         | 
| 24 | 
            +
                if name.startswith("backbone") or force_is_backbone:
         | 
| 25 | 
            +
                    if (
         | 
| 26 | 
            +
                        ".pos_embed" in name
         | 
| 27 | 
            +
                        or ".patch_embed" in name
         | 
| 28 | 
            +
                        or ".mask_token" in name
         | 
| 29 | 
            +
                        or ".cls_token" in name
         | 
| 30 | 
            +
                        or ".register_tokens" in name
         | 
| 31 | 
            +
                    ):
         | 
| 32 | 
            +
                        layer_id = 0
         | 
| 33 | 
            +
                    elif force_is_backbone and (
         | 
| 34 | 
            +
                        "pos_embed" in name
         | 
| 35 | 
            +
                        or "patch_embed" in name
         | 
| 36 | 
            +
                        or "mask_token" in name
         | 
| 37 | 
            +
                        or "cls_token" in name
         | 
| 38 | 
            +
                        or "register_tokens" in name
         | 
| 39 | 
            +
                    ):
         | 
| 40 | 
            +
                        layer_id = 0
         | 
| 41 | 
            +
                    elif ".blocks." in name and ".residual." not in name:
         | 
| 42 | 
            +
                        layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
         | 
| 43 | 
            +
                    elif chunked_blocks and "blocks." in name and "residual." not in name:
         | 
| 44 | 
            +
                        layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
         | 
| 45 | 
            +
                    elif "blocks." in name and "residual." not in name:
         | 
| 46 | 
            +
                        layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                return lr_decay_rate ** (num_layers + 1 - layer_id)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
         | 
| 52 | 
            +
                chunked_blocks = False
         | 
| 53 | 
            +
                if hasattr(model, "n_blocks"):
         | 
| 54 | 
            +
                    logger.info("chunked fsdp")
         | 
| 55 | 
            +
                    n_blocks = model.n_blocks
         | 
| 56 | 
            +
                    chunked_blocks = model.chunked_blocks
         | 
| 57 | 
            +
                elif hasattr(model, "blocks"):
         | 
| 58 | 
            +
                    logger.info("first code branch")
         | 
| 59 | 
            +
                    n_blocks = len(model.blocks)
         | 
| 60 | 
            +
                elif hasattr(model, "backbone"):
         | 
| 61 | 
            +
                    logger.info("second code branch")
         | 
| 62 | 
            +
                    n_blocks = len(model.backbone.blocks)
         | 
| 63 | 
            +
                else:
         | 
| 64 | 
            +
                    logger.info("else code branch")
         | 
| 65 | 
            +
                    n_blocks = 0
         | 
| 66 | 
            +
                all_param_groups = []
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                for name, param in model.named_parameters():
         | 
| 69 | 
            +
                    name = name.replace("_fsdp_wrapped_module.", "")
         | 
| 70 | 
            +
                    if not param.requires_grad:
         | 
| 71 | 
            +
                        continue
         | 
| 72 | 
            +
                    decay_rate = get_vit_lr_decay_rate(
         | 
| 73 | 
            +
                        name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
         | 
| 74 | 
            +
                    )
         | 
| 75 | 
            +
                    d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    if "last_layer" in name:
         | 
| 78 | 
            +
                        d.update({"is_last_layer": True})
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    if name.endswith(".bias") or "norm" in name or "gamma" in name:
         | 
| 81 | 
            +
                        d.update({"wd_multiplier": 0.0})
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    if "patch_embed" in name:
         | 
| 84 | 
            +
                        d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    all_param_groups.append(d)
         | 
| 87 | 
            +
                    logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                return all_param_groups
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
         | 
| 93 | 
            +
                fused_params_groups = defaultdict(lambda: {"params": []})
         | 
| 94 | 
            +
                for d in all_params_groups:
         | 
| 95 | 
            +
                    identifier = ""
         | 
| 96 | 
            +
                    for k in keys:
         | 
| 97 | 
            +
                        identifier += k + str(d[k]) + "_"
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    for k in keys:
         | 
| 100 | 
            +
                        fused_params_groups[identifier][k] = d[k]
         | 
| 101 | 
            +
                    fused_params_groups[identifier]["params"].append(d["params"])
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                return fused_params_groups.values()
         | 
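
A hypothetical usage sketch (not part of the committed file): TinyViT below is a made-up stand-in with the .blocks and patch_embed attributes the helpers inspect; the fused groups are handed to an optimizer, while the stored lr_multiplier / wd_multiplier values are meant to be applied by the training loop at each step.

    import torch
    from torch import nn

    class TinyViT(nn.Module):
        # Minimal stand-in, only for illustration of the attribute names inspected above.
        def __init__(self, depth=4, dim=16):
            super().__init__()
            self.patch_embed = nn.Linear(dim, dim)  # gets layer_id 0 and the patch_embed lr multiplier
            self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    model = TinyViT()
    groups = get_params_groups_with_decay(model, lr_decay_rate=0.9, patch_embed_lr_mult=0.2)
    fused = fuse_params_groups(groups)
    optimizer = torch.optim.AdamW(fused, lr=1e-3, weight_decay=0.05)  # extra keys in each group are ignored by AdamW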
    	
        pi3/models/dinov2/utils/utils.py
    ADDED
    
    | @@ -0,0 +1,95 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
            import os
         | 
| 8 | 
            +
            import random
         | 
| 9 | 
            +
            import subprocess
         | 
| 10 | 
            +
            from urllib.parse import urlparse
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            from torch import nn
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            # logger = logging.getLogger("dinov2")
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
         | 
| 21 | 
            +
    if urlparse(pretrained_weights).scheme:  # If it looks like a URL
         | 
| 22 | 
            +
                    state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
         | 
| 23 | 
            +
                else:
         | 
| 24 | 
            +
                    state_dict = torch.load(pretrained_weights, map_location="cpu")
         | 
| 25 | 
            +
                if checkpoint_key is not None and checkpoint_key in state_dict:
         | 
| 26 | 
            +
                    # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
         | 
| 27 | 
            +
                    state_dict = state_dict[checkpoint_key]
         | 
| 28 | 
            +
                # remove `module.` prefix
         | 
| 29 | 
            +
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
         | 
| 30 | 
            +
                # remove `backbone.` prefix induced by multicrop wrapper
         | 
| 31 | 
            +
                state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
         | 
| 32 | 
            +
                msg = model.load_state_dict(state_dict, strict=False)
         | 
| 33 | 
            +
                # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def fix_random_seeds(seed=31):
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                Fix random seeds.
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                torch.manual_seed(seed)
         | 
| 41 | 
            +
                torch.cuda.manual_seed_all(seed)
         | 
| 42 | 
            +
                np.random.seed(seed)
         | 
| 43 | 
            +
                random.seed(seed)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def get_sha():
         | 
| 47 | 
            +
                cwd = os.path.dirname(os.path.abspath(__file__))
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def _run(command):
         | 
| 50 | 
            +
                    return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                sha = "N/A"
         | 
| 53 | 
            +
                diff = "clean"
         | 
| 54 | 
            +
                branch = "N/A"
         | 
| 55 | 
            +
                try:
         | 
| 56 | 
            +
                    sha = _run(["git", "rev-parse", "HEAD"])
         | 
| 57 | 
            +
                    subprocess.check_output(["git", "diff"], cwd=cwd)
         | 
| 58 | 
            +
                    diff = _run(["git", "diff-index", "HEAD"])
         | 
| 59 | 
            +
                    diff = "has uncommitted changes" if diff else "clean"
         | 
| 60 | 
            +
                    branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
         | 
| 61 | 
            +
                except Exception:
         | 
| 62 | 
            +
                    pass
         | 
| 63 | 
            +
                message = f"sha: {sha}, status: {diff}, branch: {branch}"
         | 
| 64 | 
            +
                return message
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class CosineScheduler(object):
         | 
| 68 | 
            +
                def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
         | 
| 69 | 
            +
                    super().__init__()
         | 
| 70 | 
            +
                    self.final_value = final_value
         | 
| 71 | 
            +
                    self.total_iters = total_iters
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    freeze_schedule = np.zeros((freeze_iters))
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    iters = np.arange(total_iters - warmup_iters - freeze_iters)
         | 
| 78 | 
            +
                    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
         | 
| 79 | 
            +
                    self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    assert len(self.schedule) == self.total_iters
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                def __getitem__(self, it):
         | 
| 84 | 
            +
                    if it >= self.total_iters:
         | 
| 85 | 
            +
                        return self.final_value
         | 
| 86 | 
            +
                    else:
         | 
| 87 | 
            +
                        return self.schedule[it]
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def has_batchnorms(model):
         | 
| 91 | 
            +
                bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
         | 
| 92 | 
            +
                for name, module in model.named_modules():
         | 
| 93 | 
            +
                    if isinstance(module, bn_types):
         | 
| 94 | 
            +
                        return True
         | 
| 95 | 
            +
                return False
         | 
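
Usage sketch for CosineScheduler (illustrative, not part of the committed file): a linear warmup followed by a cosine decay, indexed by iteration.

    lr_schedule = CosineScheduler(
        base_value=1e-3,
        final_value=1e-5,
        total_iters=100,
        warmup_iters=10,
        start_warmup_value=0.0,
    )
    print(lr_schedule[0], lr_schedule[50], lr_schedule[99])  # warmup start, mid-cosine, near final_value
    print(lr_schedule[500])                                  # past total_iters -> final_value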
    	
        pi3/models/layers/attention.py
    ADDED
    
    | @@ -0,0 +1,369 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import logging
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            import warnings
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from torch import Tensor
         | 
| 15 | 
            +
            from torch import nn
         | 
| 16 | 
            +
            import torch
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from torch.nn.functional import scaled_dot_product_attention
         | 
| 19 | 
            +
            from torch.nn.attention import SDPBackend
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
         | 
| 22 | 
            +
            try:
         | 
| 23 | 
            +
                if XFORMERS_ENABLED:
         | 
| 24 | 
            +
                    from xformers.ops import memory_efficient_attention, unbind
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    XFORMERS_AVAILABLE = True
         | 
| 27 | 
            +
                    # warnings.warn("xFormers is available (Attention)")
         | 
| 28 | 
            +
                else:
         | 
| 29 | 
            +
                    # warnings.warn("xFormers is disabled (Attention)")
         | 
| 30 | 
            +
                    raise ImportError
         | 
| 31 | 
            +
            except ImportError:
         | 
| 32 | 
            +
                XFORMERS_AVAILABLE = False
         | 
| 33 | 
            +
                # warnings.warn("xFormers is not available (Attention)")
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class Attention(nn.Module):
         | 
| 37 | 
            +
                def __init__(
         | 
| 38 | 
            +
                    self,
         | 
| 39 | 
            +
                    dim: int,
         | 
| 40 | 
            +
                    num_heads: int = 8,
         | 
| 41 | 
            +
                    qkv_bias: bool = False,
         | 
| 42 | 
            +
                    proj_bias: bool = True,
         | 
| 43 | 
            +
                    attn_drop: float = 0.0,
         | 
| 44 | 
            +
                    proj_drop: float = 0.0,
         | 
| 45 | 
            +
                ) -> None:
         | 
| 46 | 
            +
                    super().__init__()
         | 
| 47 | 
            +
                    self.num_heads = num_heads
         | 
| 48 | 
            +
                    head_dim = dim // num_heads
         | 
| 49 | 
            +
                    self.scale = head_dim**-0.5
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 52 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 53 | 
            +
                    self.proj = nn.Linear(dim, dim, bias=proj_bias)
         | 
| 54 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         | 
| 57 | 
            +
                    B, N, C = x.shape
         | 
| 58 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
         | 
| 61 | 
            +
                    attn = q @ k.transpose(-2, -1)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 64 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 67 | 
            +
                    x = self.proj(x)
         | 
| 68 | 
            +
                    x = self.proj_drop(x)
         | 
| 69 | 
            +
                    return x
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            class MemEffAttention(Attention):
         | 
| 73 | 
            +
                def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         | 
| 74 | 
            +
                    if not XFORMERS_AVAILABLE:
         | 
| 75 | 
            +
                        if attn_bias is not None:
         | 
| 76 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 77 | 
            +
                        return super().forward(x)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    B, N, C = x.shape
         | 
| 80 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # q, k, v = unbind(qkv, 2)
         | 
| 83 | 
            +
                    q, k, v = [qkv[:,:,i] for i in range(3)]
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         | 
| 86 | 
            +
                    x = x.reshape([B, N, C])
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    x = self.proj(x)
         | 
| 89 | 
            +
                    x = self.proj_drop(x)
         | 
| 90 | 
            +
                    return x
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
                
         | 
| 94 | 
            +
            class FlashAttention(Attention):
         | 
| 95 | 
            +
                def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         | 
| 96 | 
            +
                    B, N, C = x.shape
         | 
| 97 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    # q, k, v = unbind(qkv, 2)
         | 
| 100 | 
            +
                    q, k, v = [qkv[:,:,i] for i in range(3)]
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    if q.dtype == torch.bfloat16:
         | 
| 103 | 
            +
                        with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
         | 
| 104 | 
            +
                            x = scaled_dot_product_attention(q, k, v)
         | 
| 105 | 
            +
                    else:
         | 
| 106 | 
            +
                        with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
         | 
| 107 | 
            +
                            x = scaled_dot_product_attention(q, k, v)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    x = x.transpose(1, 2).reshape([B, N, C])
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    x = self.proj(x)
         | 
| 112 | 
            +
                    x = self.proj_drop(x)
         | 
| 113 | 
            +
                    return x
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            """
         | 
| 117 | 
            +
The following was written by GPT-4o
         | 
| 118 | 
            +
            """
         | 
| 119 | 
            +
            class CrossAttentionRope(nn.Module):
         | 
| 120 | 
            +
                def __init__(
         | 
| 121 | 
            +
                    self,
         | 
| 122 | 
            +
                    dim: int,
         | 
| 123 | 
            +
                    num_heads: int = 8,
         | 
| 124 | 
            +
                    qkv_bias: bool = False,
         | 
| 125 | 
            +
                    proj_bias: bool = True,
         | 
| 126 | 
            +
                    attn_drop: float = 0.0,
         | 
| 127 | 
            +
                    proj_drop: float = 0.0,
         | 
| 128 | 
            +
                    qk_norm: bool = False,
         | 
| 129 | 
            +
                    norm_layer: nn.Module = nn.LayerNorm,
         | 
| 130 | 
            +
                    rope=None,
         | 
| 131 | 
            +
                ) -> None:
         | 
| 132 | 
            +
                    super().__init__()
         | 
| 133 | 
            +
                    self.num_heads = num_heads
         | 
| 134 | 
            +
                    head_dim = dim // num_heads
         | 
| 135 | 
            +
                    self.scale = head_dim**-0.5
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # Separate projection layers for query, key, and value
         | 
| 138 | 
            +
                    self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
         | 
| 139 | 
            +
                    self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
         | 
| 140 | 
            +
                    self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
         | 
| 143 | 
            +
                    self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 146 | 
            +
                    self.proj = nn.Linear(dim, dim, bias=proj_bias)
         | 
| 147 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    self.rope = rope
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
         | 
| 152 | 
            +
                    """
         | 
| 153 | 
            +
                    Args:
         | 
| 154 | 
            +
                        query: Tensor of shape (B, N, C), input query
         | 
| 155 | 
            +
                        key: Tensor of shape (B, M, C), input key
         | 
| 156 | 
            +
                        value: Tensor of shape (B, M, C), input value
         | 
| 157 | 
            +
                        attn_bias: Optional tensor for attention bias
         | 
| 158 | 
            +
                    Returns:
         | 
| 159 | 
            +
                        Tensor of shape (B, N, C), output of cross-attention
         | 
| 160 | 
            +
                    """
         | 
| 161 | 
            +
                    B, N, C = query.shape
         | 
| 162 | 
            +
                    _, M, _ = key.shape
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    # Project query, key, and value
         | 
| 165 | 
            +
                    q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         | 
| 166 | 
            +
                    k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         | 
| 167 | 
            +
                    v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         | 
| 168 | 
            +
                    q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    if self.rope is not None:
         | 
| 171 | 
            +
                        q = self.rope(q, qpos)
         | 
| 172 | 
            +
                        k = self.rope(k, kpos)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # Scale query
         | 
| 175 | 
            +
                    q = q * self.scale
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    # Compute attention scores
         | 
| 178 | 
            +
                    attn = q @ k.transpose(-2, -1)  # (B, num_heads, N, M)
         | 
| 179 | 
            +
                    if attn_bias is not None:
         | 
| 180 | 
            +
                        attn = attn + attn_bias
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 183 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    # Compute attention output
         | 
| 186 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, C)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    # Final projection
         | 
| 189 | 
            +
                    x = self.proj(x)
         | 
| 190 | 
            +
                    x = self.proj_drop(x)
         | 
| 191 | 
            +
                    return x
         | 
| 192 | 
            +
             | 
| 193 | 
            +
             | 
| 194 | 
            +
            class MemEffCrossAttentionRope(CrossAttentionRope):
         | 
| 195 | 
            +
                def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
         | 
| 196 | 
            +
                    """
         | 
| 197 | 
            +
                    Args:
         | 
| 198 | 
            +
                        query: Tensor of shape (B, N, C), input query
         | 
| 199 | 
            +
                        key: Tensor of shape (B, M, C), input key
         | 
| 200 | 
            +
                        value: Tensor of shape (B, M, C), input value
         | 
| 201 | 
            +
                        attn_bias: Optional tensor for attention bias
         | 
| 202 | 
            +
                    Returns:
         | 
| 203 | 
            +
                        Tensor of shape (B, N, C), output of cross-attention
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    if not XFORMERS_AVAILABLE:
         | 
| 206 | 
            +
                        if attn_bias is not None:
         | 
| 207 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 208 | 
            +
                        return super().forward(query, key, value, attn_bias)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    B, N, C = query.shape
         | 
| 211 | 
            +
                    _, M, _ = key.shape
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    # Project query, key, and value
         | 
| 214 | 
            +
                    q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
         | 
| 215 | 
            +
                    k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
         | 
| 216 | 
            +
                    v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    q = q.transpose(1, 2)
         | 
| 219 | 
            +
                    k = k.transpose(1, 2)
         | 
| 220 | 
            +
                    q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    if self.rope is not None:
         | 
| 223 | 
            +
                        q = self.rope(q, qpos)
         | 
| 224 | 
            +
                        k = self.rope(k, kpos)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    q = q.transpose(1, 2)
         | 
| 227 | 
            +
                    k = k.transpose(1, 2)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # Compute memory-efficient attention
         | 
| 230 | 
            +
                    x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         | 
| 231 | 
            +
                    x = x.reshape(B, N, C)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Final projection
         | 
| 234 | 
            +
                    x = self.proj(x)
         | 
| 235 | 
            +
                    x = self.proj_drop(x)
         | 
| 236 | 
            +
                    return x
         | 
| 237 | 
            +
             | 
| 238 | 
            +
            class AttentionRope(nn.Module):
         | 
| 239 | 
            +
                def __init__(
         | 
| 240 | 
            +
                    self,
         | 
| 241 | 
            +
                    dim: int,
         | 
| 242 | 
            +
                    num_heads: int = 8,
         | 
| 243 | 
            +
                    qkv_bias: bool = False,
         | 
| 244 | 
            +
                    proj_bias: bool = True,
         | 
| 245 | 
            +
                    attn_drop: float = 0.0,
         | 
| 246 | 
            +
                    proj_drop: float = 0.0,
         | 
| 247 | 
            +
                    qk_norm: bool = False,
         | 
| 248 | 
            +
                    norm_layer: nn.Module = nn.LayerNorm,
         | 
| 249 | 
            +
                    rope=None
         | 
| 250 | 
            +
                ) -> None:
         | 
| 251 | 
            +
                    super().__init__()
         | 
| 252 | 
            +
                    self.num_heads = num_heads
         | 
| 253 | 
            +
                    head_dim = dim // num_heads
         | 
| 254 | 
            +
                    self.scale = head_dim**-0.5
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 257 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 258 | 
            +
                    self.proj = nn.Linear(dim, dim, bias=proj_bias)
         | 
| 259 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
         | 
| 262 | 
            +
                    self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    self.rope = rope
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
         | 
| 267 | 
            +
                    B, N, C = x.shape
         | 
| 268 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         | 
| 269 | 
            +
                    q, k, v = qkv[0], qkv[1], qkv[2]
         | 
| 270 | 
            +
                    q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    if self.rope is not None:
         | 
| 273 | 
            +
                        q = self.rope(q, xpos)
         | 
| 274 | 
            +
                        k = self.rope(k, xpos)
         | 
| 275 | 
            +
                    
         | 
| 276 | 
            +
                    q = q * self.scale
         | 
| 277 | 
            +
                    attn = q @ k.transpose(-2, -1)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 280 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 283 | 
            +
                    x = self.proj(x)
         | 
| 284 | 
            +
                    x = self.proj_drop(x)
         | 
| 285 | 
            +
                    return x
         | 
| 286 | 
            +
             | 
| 287 | 
            +
             | 
| 288 | 
            +
            class MemEffAttentionRope(AttentionRope):
         | 
| 289 | 
            +
                def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
         | 
| 290 | 
            +
                    if not XFORMERS_AVAILABLE:
         | 
| 291 | 
            +
                        if attn_bias is not None:
         | 
| 292 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 293 | 
            +
                        return super().forward(x)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    B, N, C = x.shape
         | 
| 296 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
         | 
| 297 | 
            +
                    
         | 
| 298 | 
            +
                    qkv = qkv.transpose(1, 3)
         | 
| 299 | 
            +
                    # q, k, v = unbind(qkv, 2)
         | 
| 300 | 
            +
                    q, k, v = [qkv[:,:,i] for i in range(3)]
         | 
| 301 | 
            +
                    q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    if self.rope is not None:
         | 
| 304 | 
            +
                        q = self.rope(q, xpos)
         | 
| 305 | 
            +
                        k = self.rope(k, xpos)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    q = q.transpose(1, 2)
         | 
| 308 | 
            +
                    k = k.transpose(1, 2)
         | 
| 309 | 
            +
                    v = v.transpose(1, 2)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         | 
| 312 | 
            +
                    x = x.reshape([B, N, C])
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1)         # for frame attention matrix
         | 
| 315 | 
            +
                    # global_valid_id = torch.where(score_matrix > 0)
         | 
| 316 | 
            +
                    # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    x = self.proj(x)
         | 
| 319 | 
            +
                    x = self.proj_drop(x)
         | 
| 320 | 
            +
                    return x
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                
         | 
| 323 | 
            +
            class FlashAttentionRope(AttentionRope):
         | 
| 324 | 
            +
                def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
         | 
| 325 | 
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:, :, i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        if q.dtype == torch.bfloat16:
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = scaled_dot_product_attention(q, k, v)
        else:
            with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                x = scaled_dot_product_attention(q, k, v)

        x = x.transpose(1, 2).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
    x = blk_class.norm1(x)

    B, N, C = x.shape
    qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads)

    qkv = qkv.transpose(1, 3)
    # q, k, v = unbind(qkv, 2)
    q, k, v = [qkv[:, :, i] for i in range(3)]
    q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)

    if blk_class.attn.rope is not None:
        q = blk_class.attn.rope(q, xpos)
        k = blk_class.attn.rope(k, xpos)

    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1)

    return score
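A minimal shape sketch (not part of the commit) of the frame-level reduction performed by `get_attn_score`; the random tensor below is only a stand-in for the per-head scores `q * scale @ k^T`:

```python
import torch

# Stand-in for per-head token-to-token scores: [B, heads, N, N] with N = frame_num * token_length.
B, heads, frame_num, token_length = 1, 4, 3, 5
N = frame_num * token_length
raw = torch.randn(B, heads, N, N)

score = (raw.sum(dim=1)                                            # sum over heads -> [B, N, N]
            .reshape(B, frame_num, token_length, frame_num, token_length)
            .mean(dim=[2, 4])                                      # average tokens per frame pair -> [B, F, F]
            .sum(-1))                                              # total score per query frame -> [B, F]
print(score.shape)  # torch.Size([1, 3])
```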
    	
        pi3/models/layers/block.py
    ADDED
    
@@ -0,0 +1,406 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope
from ..dinov2.layers.drop_path import DropPath
from ..dinov2.layers.layer_scale import LayerScale
from ..dinov2.layers.mlp import Mlp


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
            def drop_add_residual_stochastic_depth(
         | 
| 115 | 
            +
                x: Tensor,
         | 
| 116 | 
            +
                residual_func: Callable[[Tensor], Tensor],
         | 
| 117 | 
            +
                sample_drop_ratio: float = 0.0,
         | 
| 118 | 
            +
            ) -> Tensor:
         | 
| 119 | 
            +
                # 1) extract subset using permutation
         | 
| 120 | 
            +
                b, n, d = x.shape
         | 
| 121 | 
            +
                sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
         | 
| 122 | 
            +
                brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
         | 
| 123 | 
            +
                x_subset = x[brange]
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                # 2) apply residual_func to get residual
         | 
| 126 | 
            +
                residual = residual_func(x_subset)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                x_flat = x.flatten(1)
         | 
| 129 | 
            +
                residual = residual.flatten(1)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                residual_scale_factor = b / sample_subset_size
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # 3) add the residual
         | 
| 134 | 
            +
                x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
         | 
| 135 | 
            +
                return x_plus_residual.view_as(x)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            def get_branges_scales(x, sample_drop_ratio=0.0):
         | 
| 139 | 
            +
                b, n, d = x.shape
         | 
| 140 | 
            +
                sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
         | 
| 141 | 
            +
                brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
         | 
| 142 | 
            +
                residual_scale_factor = b / sample_subset_size
         | 
| 143 | 
            +
                return brange, residual_scale_factor
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
         | 
| 147 | 
            +
                if scaling_vector is None:
         | 
| 148 | 
            +
                    x_flat = x.flatten(1)
         | 
| 149 | 
            +
                    residual = residual.flatten(1)
         | 
| 150 | 
            +
                    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
         | 
| 151 | 
            +
                else:
         | 
| 152 | 
            +
                    x_plus_residual = scaled_index_add(
         | 
| 153 | 
            +
                        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
         | 
| 154 | 
            +
                    )
         | 
| 155 | 
            +
                return x_plus_residual
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
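An illustrative sketch (not part of the commit) of what `drop_add_residual_stochastic_depth` does: only a random subset of samples receives the residual, scaled by `b / sample_subset_size` so the update is unbiased in expectation.

```python
import torch
from pi3.models.layers.block import drop_add_residual_stochastic_depth  # path assumed

x = torch.zeros(8, 4, 16)                            # batch of 8 zero samples
out = drop_add_residual_stochastic_depth(
    x,
    residual_func=lambda t: torch.ones_like(t),      # toy residual branch
    sample_drop_ratio=0.5,
)
# 4 of the 8 samples get the residual, scaled by 8 / 4 = 2.0; per-sample sums are 0 or 128.
print(out.sum(dim=(1, 2)))
```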
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
             | 
| 208 | 
            +
            class NestedTensorBlock(Block):
         | 
| 209 | 
            +
                def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
         | 
| 210 | 
            +
                    """
         | 
| 211 | 
            +
                    x_list contains a list of tensors to nest together and run
         | 
| 212 | 
            +
                    """
         | 
| 213 | 
            +
                    assert isinstance(self.attn, MemEffAttention)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    if self.training and self.sample_drop_ratio > 0.0:
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 218 | 
            +
                            return self.attn(self.norm1(x), attn_bias=attn_bias)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 221 | 
            +
                            return self.mlp(self.norm2(x))
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                        x_list = drop_add_residual_stochastic_depth_list(
         | 
| 224 | 
            +
                            x_list,
         | 
| 225 | 
            +
                            residual_func=attn_residual_func,
         | 
| 226 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 227 | 
            +
                            scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
         | 
| 228 | 
            +
                        )
         | 
| 229 | 
            +
                        x_list = drop_add_residual_stochastic_depth_list(
         | 
| 230 | 
            +
                            x_list,
         | 
| 231 | 
            +
                            residual_func=ffn_residual_func,
         | 
| 232 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 233 | 
            +
                            scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
         | 
| 234 | 
            +
                        )
         | 
| 235 | 
            +
                        return x_list
         | 
| 236 | 
            +
                    else:
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 239 | 
            +
                            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 242 | 
            +
                            return self.ls2(self.mlp(self.norm2(x)))
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                        attn_bias, x = get_attn_bias_and_cat(x_list)
         | 
| 245 | 
            +
                        x = x + attn_residual_func(x, attn_bias=attn_bias)
         | 
| 246 | 
            +
                        x = x + ffn_residual_func(x)
         | 
| 247 | 
            +
                        return attn_bias.split(x)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                def forward(self, x_or_x_list):
         | 
| 250 | 
            +
                    if isinstance(x_or_x_list, Tensor):
         | 
| 251 | 
            +
                        return super().forward(x_or_x_list)
         | 
| 252 | 
            +
                    elif isinstance(x_or_x_list, list):
         | 
| 253 | 
            +
                        if not XFORMERS_AVAILABLE:
         | 
| 254 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 255 | 
            +
                        return self.forward_nested(x_or_x_list)
         | 
| 256 | 
            +
                    else:
         | 
| 257 | 
            +
                        raise AssertionError
         | 
| 258 | 
            +
             | 
class BlockRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, xpos=None) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
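A hedged usage sketch (not part of the commit): `BlockRope` differs from `Block` mainly in threading `xpos` through a RoPE-aware attention class. The constructor signatures of `FlashAttentionRope` and `RoPE2D` are assumed from how this file and `pos_embed.py` use them, and the SDPA path needs a recent PyTorch with `torch.nn.attention`.

```python
import torch
from pi3.models.layers.block import BlockRope
from pi3.models.layers.attention import FlashAttentionRope
from pi3.models.layers.pos_embed import RoPE2D

blk = BlockRope(dim=64, num_heads=4, qkv_bias=True,
                attn_class=FlashAttentionRope, rope=RoPE2D(freq=100.0)).eval()

H = W = 4                                   # 4x4 patch grid
x = torch.randn(2, H * W, 64)               # (batch, tokens, dim)
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
xpos = torch.stack([ys.flatten(), xs.flatten()], dim=-1).unsqueeze(0).expand(2, -1, -1)
with torch.no_grad():
    out = blk(x, xpos=xpos)
print(out.shape)                            # torch.Size([2, 16, 64])
```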
class CrossBlockRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool = False,
        rope=None
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm
        )

        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.ls_y = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm
        )

        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
            return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm3(x)))

        x = x + attn_residual_func(x)
        y_ = self.norm_y(y)
        x = x + cross_attn_residual_func(x, y_)
        x = x + ffn_residual_func(x)

        return x
    	
        pi3/models/layers/camera_head.py
    ADDED
    
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F

# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
class ResConvBlock(nn.Module):
    """
    1x1 convolution residual block
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.head_skip = nn.Identity() if self.in_channels == self.out_channels else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
        # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)

        # change 1x1 convolution to linear
        self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
        self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
        self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)

    def forward(self, res):
        x = F.relu(self.res_conv1(res))
        x = F.relu(self.res_conv2(x))
        x = F.relu(self.res_conv3(x))
        res = self.head_skip(res) + x
        return res
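Despite the name, this block operates on token features of shape `(B, N, C)`: the 1x1 convolutions have been replaced by `nn.Linear` over the channel dimension. A quick sketch (not part of the commit):

```python
import torch
from pi3.models.layers.camera_head import ResConvBlock  # import path assumed

blk = ResConvBlock(512, 512)                # equal channels -> identity skip
out = blk(torch.randn(2, 196, 512))         # (batch, tokens, channels)
print(out.shape)                            # torch.Size([2, 196, 512])
```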
class CameraHead(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        self.res_conv = nn.ModuleList([deepcopy(ResConvBlock(output_dim, output_dim))
                for _ in range(2)])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
            )
        self.fc_t = nn.Linear(output_dim, 3)
        self.fc_rot = nn.Linear(output_dim, 9)

    def forward(self, feat, patch_h, patch_w):
        BN, hw, c = feat.shape

        for i in range(2):
            feat = self.res_conv[i](feat)

        # feat = self.avgpool(feat)
        feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous())
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [B, D_]
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out_t = self.fc_t(feat.float())  # [B,3]
            out_r = self.fc_rot(feat.float())  # [B,9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        out_r = self.svd_orthogonalize(out_r)  # [N,3,3]
        pose = torch.zeros((B, 4, 4), device=device)
        pose[:, :3, :3] = out_r
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.
        return pose

    def svd_orthogonalize(self, m):
        """Convert 9D representation to SO(3) using SVD orthogonalization.

        Args:
          m: [BATCH, 3, 3] 3x3 matrices.

        Returns:
          [BATCH, 3, 3] SO(3) rotation matrices.
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
        u, s, v = torch.svd(m_transpose)
        det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
        # Check orientation reflection.
        r = torch.matmul(
            torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
            u.transpose(-2, -1)
        )
        return r
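A small sanity sketch (not part of the commit) for `svd_orthogonalize`: an arbitrary 9D vector is projected to a proper rotation (orthonormal, determinant +1), which is what `convert_pose_to_4x4` relies on.

```python
import torch
from pi3.models.layers.camera_head import CameraHead  # import path assumed

head = CameraHead(dim=512)
r = head.svd_orthogonalize(torch.randn(4, 9))          # [4, 3, 3]
print(torch.allclose(r @ r.transpose(-2, -1), torch.eye(3).expand(4, -1, -1), atol=1e-4))
print(torch.det(r))                                    # all close to +1
```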
    	
        pi3/models/layers/pos_embed.py
    ADDED
    
@@ -0,0 +1,174 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
             | 
| 72 | 
            +
            # --------------------------------------------------------
         | 
| 73 | 
            +
            # Interpolate position embeddings for high-resolution
         | 
| 74 | 
            +
            # References:
         | 
| 75 | 
            +
            # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
         | 
| 76 | 
            +
            # DeiT: https://github.com/facebookresearch/deit
         | 
| 77 | 
            +
            # --------------------------------------------------------
         | 
| 78 | 
            +
            def interpolate_pos_embed(model, checkpoint_model):
         | 
| 79 | 
            +
                if 'pos_embed' in checkpoint_model:
         | 
| 80 | 
            +
                    pos_embed_checkpoint = checkpoint_model['pos_embed']
         | 
| 81 | 
            +
                    embedding_size = pos_embed_checkpoint.shape[-1]
         | 
| 82 | 
            +
                    num_patches = model.patch_embed.num_patches
         | 
| 83 | 
            +
                    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
         | 
| 84 | 
            +
                    # height (== width) for the checkpoint position embedding
         | 
| 85 | 
            +
                    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
         | 
| 86 | 
            +
                    # height (== width) for the new position embedding
         | 
| 87 | 
            +
                    new_size = int(num_patches ** 0.5)
         | 
| 88 | 
            +
                    # class_token and dist_token are kept unchanged
         | 
| 89 | 
            +
                    if orig_size != new_size:
         | 
| 90 | 
            +
                        print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
         | 
| 91 | 
            +
                        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
         | 
| 92 | 
            +
                        # only the position tokens are interpolated
         | 
| 93 | 
            +
                        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         | 
| 94 | 
            +
                        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
         | 
| 95 | 
            +
                        pos_tokens = torch.nn.functional.interpolate(
         | 
| 96 | 
            +
                            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
         | 
| 97 | 
            +
                        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
         | 
| 98 | 
            +
                        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
         | 
| 99 | 
            +
                        checkpoint_model['pos_embed'] = new_pos_embed
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
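An illustrative call (not part of the commit) with mocked `model` attributes, showing a checkpoint trained on a 14x14 grid being resized for a 16x16 grid with one extra (cls) token:

```python
import torch
from types import SimpleNamespace
from pi3.models.layers.pos_embed import interpolate_pos_embed  # import path assumed

model = SimpleNamespace(
    patch_embed=SimpleNamespace(num_patches=16 * 16),
    pos_embed=torch.zeros(1, 16 * 16 + 1, 384),
)
ckpt = {'pos_embed': torch.randn(1, 14 * 14 + 1, 384)}
interpolate_pos_embed(model, ckpt)          # prints "Position interpolate from 14x14 to 16x16"
print(ckpt['pos_embed'].shape)              # torch.Size([1, 257, 384])
```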
| 102 | 
            +
            #----------------------------------------------------------
         | 
| 103 | 
            +
            # RoPE2D: RoPE implementation in 2D
         | 
| 104 | 
            +
            #----------------------------------------------------------
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            try:
         | 
| 107 | 
            +
                from models.curope import cuRoPE2D
         | 
| 108 | 
            +
                RoPE2D = cuRoPE2D
         | 
| 109 | 
            +
            except ImportError:
         | 
| 110 | 
            +
                print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                class RoPE2D(torch.nn.Module):
         | 
| 113 | 
            +
                    
         | 
| 114 | 
            +
                    def __init__(self, freq=100.0, F0=1.0):
         | 
| 115 | 
            +
                        super().__init__()
         | 
| 116 | 
            +
                        self.base = freq 
         | 
| 117 | 
            +
                        self.F0 = F0
         | 
| 118 | 
            +
                        self.cache = {}
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    def get_cos_sin(self, D, seq_len, device, dtype):
         | 
| 121 | 
            +
                        if (D,seq_len,device,dtype) not in self.cache:
         | 
| 122 | 
            +
                            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
         | 
| 123 | 
            +
                            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
         | 
| 124 | 
            +
                            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
         | 
| 125 | 
            +
                            freqs = torch.cat((freqs, freqs), dim=-1)
         | 
| 126 | 
            +
                            cos = freqs.cos() # (Seq, Dim)
         | 
| 127 | 
            +
                            sin = freqs.sin()
         | 
| 128 | 
            +
                            self.cache[D,seq_len,device,dtype] = (cos,sin)
         | 
| 129 | 
            +
                        return self.cache[D,seq_len,device,dtype]
         | 
| 130 | 
            +
                        
         | 
| 131 | 
            +
                    @staticmethod
         | 
| 132 | 
            +
                    def rotate_half(x):
         | 
| 133 | 
            +
                        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
         | 
| 134 | 
            +
                        return torch.cat((-x2, x1), dim=-1)
         | 
| 135 | 
            +
                        
         | 
| 136 | 
            +
                    def apply_rope1d(self, tokens, pos1d, cos, sin):
         | 
| 137 | 
            +
                        assert pos1d.ndim==2
         | 
| 138 | 
            +
                        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
         | 
| 139 | 
            +
                        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
         | 
| 140 | 
            +
                        return (tokens * cos) + (self.rotate_half(tokens) * sin)
         | 
| 141 | 
            +
                        
         | 
| 142 | 
            +
                    def forward(self, tokens, positions):
         | 
| 143 | 
            +
                        """
         | 
| 144 | 
            +
                        input:
         | 
| 145 | 
            +
                            * tokens: batch_size x nheads x ntokens x dim
         | 
| 146 | 
            +
                            * positions: batch_size x ntokens x 2 (y and x position of each token)
         | 
| 147 | 
            +
                        output:
         | 
| 148 | 
            +
                            * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
         | 
| 149 | 
            +
                        """
         | 
| 150 | 
            +
                        assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
         | 
| 151 | 
            +
                        D = tokens.size(3) // 2
         | 
| 152 | 
            +
                        assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
         | 
| 153 | 
            +
                        cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
         | 
| 154 | 
            +
                        # split features into two along the feature dimension, and apply rope1d on each half
         | 
| 155 | 
            +
                        y, x = tokens.chunk(2, dim=-1)
         | 
| 156 | 
            +
                        y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
         | 
| 157 | 
            +
                        x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
         | 
| 158 | 
            +
                        tokens = torch.cat((y, x), dim=-1)
         | 
| 159 | 
            +
                        return tokens
         | 
| 160 | 
            +
                 
         | 
| 161 | 
            +
# patch positions
         | 
| 162 | 
            +
            class PositionGetter(object):
         | 
| 163 | 
            +
                """ return positions of patches """
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def __init__(self):
         | 
| 166 | 
            +
                    self.cache_positions = {}
         | 
| 167 | 
            +
                    
         | 
| 168 | 
            +
                def __call__(self, b, h, w, device):
         | 
| 169 | 
            +
                    if not (h,w) in self.cache_positions:
         | 
| 170 | 
            +
                        x = torch.arange(w, device=device)
         | 
| 171 | 
            +
                        y = torch.arange(h, device=device)
         | 
| 172 | 
            +
            self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
         | 
| 173 | 
            +
                    pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
         | 
| 174 | 
            +
                    return pos
         | 
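A quick sanity check of the slow PyTorch fallback above: the sketch below builds dummy attention tokens, asks PositionGetter for per-patch (y, x) positions, and rotates the tokens with RoPE2D. The freq=100.0 default and the tensor layouts follow the code above; the batch size, head count and patch grid are made up for illustration.

import torch

B, heads, h, w, dim = 2, 4, 6, 8, 64            # dim must be even (asserted in forward)
rope = RoPE2D(freq=100.0)                       # the pure-PyTorch fallback defined above
pos_getter = PositionGetter()

tokens = torch.randn(B, heads, h * w, dim)      # batch_size x nheads x ntokens x dim
positions = pos_getter(B, h, w, tokens.device)  # batch_size x ntokens x 2 (y, x of each patch)

rotated = rope(tokens, positions)
print(rotated.shape)                            # torch.Size([2, 4, 48, 64])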
    	
        pi3/models/layers/transformer_head.py
    ADDED
    
    | @@ -0,0 +1,81 @@ | |
| 1 | 
            +
            from .attention import FlashAttentionRope
         | 
| 2 | 
            +
            from .block import BlockRope
         | 
| 3 | 
            +
            from ..dinov2.layers import Mlp
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            from functools import partial
         | 
| 6 | 
            +
            from torch.utils.checkpoint import checkpoint
         | 
| 7 | 
            +
            import torch.nn.functional as F
         | 
| 8 | 
            +
               
         | 
| 9 | 
            +
            class TransformerDecoder(nn.Module):
         | 
| 10 | 
            +
                def __init__(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    in_dim,
         | 
| 13 | 
            +
                    out_dim,
         | 
| 14 | 
            +
                    dec_embed_dim=512,
         | 
| 15 | 
            +
                    depth=5,
         | 
| 16 | 
            +
                    dec_num_heads=8,
         | 
| 17 | 
            +
                    mlp_ratio=4,
         | 
| 18 | 
            +
                    rope=None,
         | 
| 19 | 
            +
                    need_project=True,
         | 
| 20 | 
            +
                    use_checkpoint=False,
         | 
| 21 | 
            +
                ):
         | 
| 22 | 
            +
                    super().__init__()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
         | 
| 25 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 28 | 
            +
                        BlockRope(
         | 
| 29 | 
            +
                            dim=dec_embed_dim,
         | 
| 30 | 
            +
                            num_heads=dec_num_heads,
         | 
| 31 | 
            +
                            mlp_ratio=mlp_ratio,
         | 
| 32 | 
            +
                            qkv_bias=True,
         | 
| 33 | 
            +
                            proj_bias=True,
         | 
| 34 | 
            +
                            ffn_bias=True,
         | 
| 35 | 
            +
                            drop_path=0.0,
         | 
| 36 | 
            +
                            norm_layer=partial(nn.LayerNorm, eps=1e-6),
         | 
| 37 | 
            +
                            act_layer=nn.GELU,
         | 
| 38 | 
            +
                            ffn_layer=Mlp,
         | 
| 39 | 
            +
                            init_values=None,
         | 
| 40 | 
            +
                            qk_norm=False,
         | 
| 41 | 
            +
                            # attn_class=MemEffAttentionRope,
         | 
| 42 | 
            +
                            attn_class=FlashAttentionRope,
         | 
| 43 | 
            +
                            rope=rope
         | 
| 44 | 
            +
                        ) for _ in range(depth)])
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    self.linear_out = nn.Linear(dec_embed_dim, out_dim)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def forward(self, hidden, xpos=None):
         | 
| 49 | 
            +
                    hidden = self.projects(hidden)
         | 
| 50 | 
            +
                    for i, blk in enumerate(self.blocks):
         | 
| 51 | 
            +
                        if self.use_checkpoint and self.training:
         | 
| 52 | 
            +
                            hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
         | 
| 53 | 
            +
                        else:
         | 
| 54 | 
            +
                            hidden = blk(hidden, xpos=xpos)
         | 
| 55 | 
            +
                    out = self.linear_out(hidden)
         | 
| 56 | 
            +
                    return out
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            class LinearPts3d (nn.Module):
         | 
| 59 | 
            +
                """ 
         | 
| 60 | 
            +
                Linear head for dust3r
         | 
| 61 | 
            +
    Each token outputs a patch_size x patch_size block of per-pixel values (3D points here; confidence when output_dim=1)
         | 
| 62 | 
            +
                """
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def __init__(self, patch_size, dec_embed_dim, output_dim=3,):
         | 
| 65 | 
            +
                    super().__init__()
         | 
| 66 | 
            +
                    self.patch_size = patch_size
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def forward(self, decout, img_shape):
         | 
| 71 | 
            +
                    H, W = img_shape
         | 
| 72 | 
            +
                    tokens = decout[-1]
         | 
| 73 | 
            +
                    B, S, D = tokens.shape
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # extract 3D points
         | 
| 76 | 
            +
        feat = self.proj(tokens)  # B, S, output_dim * patch_size**2
         | 
| 77 | 
            +
                    feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
         | 
| 78 | 
            +
                    feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W
         | 
| 79 | 
            +
             | 
| 80 | 
            +
        # move channels last -> B, H, W, output_dim
         | 
| 81 | 
            +
                    return feat.permute(0, 2, 3, 1)
         | 
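LinearPts3d is the only piece of this file with no dependency on BlockRope/FlashAttentionRope, so a shape check is easy to sketch. Assuming illustrative sizes (14-pixel patches on a 224x224 image, 1024-dim tokens), each token is unpatchified into a 14x14 block of 3-vectors:

import torch

H, W, patch = 224, 224, 14
B, S, D = 2, (H // patch) * (W // patch), 1024   # S = 16 * 16 = 256 patch tokens
head = LinearPts3d(patch_size=patch, dec_embed_dim=D, output_dim=3)

tokens = torch.randn(B, S, D)
points = head([tokens], img_shape=(H, W))        # forward takes a list and reads its last entry
print(points.shape)                              # torch.Size([2, 224, 224, 3])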
    	
        pi3/models/pi3.py
    ADDED
    
    | @@ -0,0 +1,216 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            from functools import partial
         | 
| 4 | 
            +
            from copy import deepcopy
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from .dinov2.layers import Mlp
         | 
| 7 | 
            +
            from ..utils.geometry import homogenize_points
         | 
| 8 | 
            +
            from .layers.pos_embed import RoPE2D, PositionGetter
         | 
| 9 | 
            +
            from .layers.block import BlockRope
         | 
| 10 | 
            +
            from .layers.attention import FlashAttentionRope
         | 
| 11 | 
            +
            from .layers.transformer_head import TransformerDecoder, LinearPts3d
         | 
| 12 | 
            +
            from .layers.camera_head import CameraHead
         | 
| 13 | 
            +
            from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg
         | 
| 14 | 
            +
            from huggingface_hub import PyTorchModelHubMixin
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            class Pi3(nn.Module, PyTorchModelHubMixin):
         | 
| 17 | 
            +
                def __init__(
         | 
| 18 | 
            +
                        self,
         | 
| 19 | 
            +
                        pos_type='rope100',
         | 
| 20 | 
            +
                        decoder_size='large',
         | 
| 21 | 
            +
                    ):
         | 
| 22 | 
            +
                    super().__init__()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    # ----------------------
         | 
| 25 | 
            +
                    #        Encoder
         | 
| 26 | 
            +
                    # ----------------------
         | 
| 27 | 
            +
                    self.encoder = dinov2_vitl14_reg(pretrained=False)
         | 
| 28 | 
            +
                    self.patch_size = 14
         | 
| 29 | 
            +
                    del self.encoder.mask_token
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    # ----------------------
         | 
| 32 | 
            +
        #  Positional Encoding
         | 
| 33 | 
            +
                    # ----------------------
         | 
| 34 | 
            +
                    self.pos_type = pos_type if pos_type is not None else 'none'
         | 
| 35 | 
            +
                    self.rope=None
         | 
| 36 | 
            +
        if self.pos_type.startswith('rope'): # e.g. rope100
         | 
| 37 | 
            +
                        if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
         | 
| 38 | 
            +
                        freq = float(self.pos_type[len('rope'):])
         | 
| 39 | 
            +
                        self.rope = RoPE2D(freq=freq)
         | 
| 40 | 
            +
                        self.position_getter = PositionGetter()
         | 
| 41 | 
            +
                    else:
         | 
| 42 | 
            +
                        raise NotImplementedError
         | 
| 43 | 
            +
                    
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    # ----------------------
         | 
| 46 | 
            +
                    #        Decoder
         | 
| 47 | 
            +
                    # ----------------------
         | 
| 48 | 
            +
                    enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features        # 1024
         | 
| 49 | 
            +
                    if decoder_size == 'small':
         | 
| 50 | 
            +
                        dec_embed_dim = 384
         | 
| 51 | 
            +
                        dec_num_heads = 6
         | 
| 52 | 
            +
                        mlp_ratio = 4
         | 
| 53 | 
            +
                        dec_depth = 24
         | 
| 54 | 
            +
                    elif decoder_size == 'base':
         | 
| 55 | 
            +
                        dec_embed_dim = 768
         | 
| 56 | 
            +
                        dec_num_heads = 12
         | 
| 57 | 
            +
                        mlp_ratio = 4
         | 
| 58 | 
            +
                        dec_depth = 24
         | 
| 59 | 
            +
                    elif decoder_size == 'large':
         | 
| 60 | 
            +
                        dec_embed_dim = 1024
         | 
| 61 | 
            +
                        dec_num_heads = 16
         | 
| 62 | 
            +
                        mlp_ratio = 4
         | 
| 63 | 
            +
                        dec_depth = 36
         | 
| 64 | 
            +
                    else:
         | 
| 65 | 
            +
                        raise NotImplementedError
         | 
| 66 | 
            +
                    self.decoder = nn.ModuleList([
         | 
| 67 | 
            +
                        BlockRope(
         | 
| 68 | 
            +
                            dim=dec_embed_dim,
         | 
| 69 | 
            +
                            num_heads=dec_num_heads,
         | 
| 70 | 
            +
                            mlp_ratio=mlp_ratio,
         | 
| 71 | 
            +
                            qkv_bias=True,
         | 
| 72 | 
            +
                            proj_bias=True,
         | 
| 73 | 
            +
                            ffn_bias=True,
         | 
| 74 | 
            +
                            drop_path=0.0,
         | 
| 75 | 
            +
                            norm_layer=partial(nn.LayerNorm, eps=1e-6),
         | 
| 76 | 
            +
                            act_layer=nn.GELU,
         | 
| 77 | 
            +
                            ffn_layer=Mlp,
         | 
| 78 | 
            +
                            init_values=0.01,
         | 
| 79 | 
            +
                            qk_norm=True,
         | 
| 80 | 
            +
                            attn_class=FlashAttentionRope,
         | 
| 81 | 
            +
                            rope=self.rope
         | 
| 82 | 
            +
                        ) for _ in range(dec_depth)])
         | 
| 83 | 
            +
                    self.dec_embed_dim = dec_embed_dim
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    # ----------------------
         | 
| 86 | 
            +
                    #     Register_token
         | 
| 87 | 
            +
                    # ----------------------
         | 
| 88 | 
            +
                    num_register_tokens = 5
         | 
| 89 | 
            +
                    self.patch_start_idx = num_register_tokens
         | 
| 90 | 
            +
                    self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
         | 
| 91 | 
            +
                    nn.init.normal_(self.register_token, std=1e-6)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    # ----------------------
         | 
| 94 | 
            +
                    #  Local Points Decoder
         | 
| 95 | 
            +
                    # ----------------------
         | 
| 96 | 
            +
                    self.point_decoder = TransformerDecoder(
         | 
| 97 | 
            +
                        in_dim=2*self.dec_embed_dim, 
         | 
| 98 | 
            +
                        dec_embed_dim=1024,
         | 
| 99 | 
            +
                        dec_num_heads=16,
         | 
| 100 | 
            +
                        out_dim=1024,
         | 
| 101 | 
            +
                        rope=self.rope,
         | 
| 102 | 
            +
                    )
         | 
| 103 | 
            +
                    self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    # ----------------------
         | 
| 106 | 
            +
                    #     Conf Decoder
         | 
| 107 | 
            +
                    # ----------------------
         | 
| 108 | 
            +
                    self.conf_decoder = deepcopy(self.point_decoder)
         | 
| 109 | 
            +
                    self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    # ----------------------
         | 
| 112 | 
            +
                    #  Camera Pose Decoder
         | 
| 113 | 
            +
                    # ----------------------
         | 
| 114 | 
            +
                    self.camera_decoder = TransformerDecoder(
         | 
| 115 | 
            +
                        in_dim=2*self.dec_embed_dim, 
         | 
| 116 | 
            +
                        dec_embed_dim=1024,
         | 
| 117 | 
            +
            dec_num_heads=16,
         | 
| 118 | 
            +
                        out_dim=512,
         | 
| 119 | 
            +
                        rope=self.rope,
         | 
| 120 | 
            +
                        use_checkpoint=False
         | 
| 121 | 
            +
                    )
         | 
| 122 | 
            +
                    self.camera_head = CameraHead(dim=512)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
        # ImageNet normalization constants
         | 
| 125 | 
            +
                    image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
         | 
| 126 | 
            +
                    image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    self.register_buffer("image_mean", image_mean)
         | 
| 129 | 
            +
                    self.register_buffer("image_std", image_std)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
                def decode(self, hidden, N, H, W):
         | 
| 133 | 
            +
                    BN, hw, _ = hidden.shape
         | 
| 134 | 
            +
                    B = BN // N
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    final_output = []
         | 
| 137 | 
            +
                    
         | 
| 138 | 
            +
                    hidden = hidden.reshape(B*N, hw, -1)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    # Concatenate special tokens with patch tokens
         | 
| 143 | 
            +
                    hidden = torch.cat([register_token, hidden], dim=1)
         | 
| 144 | 
            +
                    hw = hidden.shape[1]
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    if self.pos_type.startswith('rope'):
         | 
| 147 | 
            +
                        pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    if self.patch_start_idx > 0:
         | 
| 150 | 
            +
                        # do not use position embedding for special tokens (camera and register tokens)
         | 
| 151 | 
            +
                        # so set pos to 0 for the special tokens
         | 
| 152 | 
            +
                        pos = pos + 1
         | 
| 153 | 
            +
                        pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
         | 
| 154 | 
            +
                        pos = torch.cat([pos_special, pos], dim=1)
         | 
| 155 | 
            +
                   
         | 
| 156 | 
            +
                    for i in range(len(self.decoder)):
         | 
| 157 | 
            +
                        blk = self.decoder[i]
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                        if i % 2 == 0:
         | 
| 160 | 
            +
                            pos = pos.reshape(B*N, hw, -1)
         | 
| 161 | 
            +
                            hidden = hidden.reshape(B*N, hw, -1)
         | 
| 162 | 
            +
                        else:
         | 
| 163 | 
            +
                            pos = pos.reshape(B, N*hw, -1)
         | 
| 164 | 
            +
                            hidden = hidden.reshape(B, N*hw, -1)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                        hidden = blk(hidden, xpos=pos)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                        if i+1 in [len(self.decoder)-1, len(self.decoder)]:
         | 
| 169 | 
            +
                            final_output.append(hidden.reshape(B*N, hw, -1))
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                def forward(self, imgs):
         | 
| 174 | 
            +
                    imgs = (imgs - self.image_mean) / self.image_std
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    B, N, _, H, W = imgs.shape
         | 
| 177 | 
            +
                    patch_h, patch_w = H // 14, W // 14
         | 
| 178 | 
            +
                    
         | 
| 179 | 
            +
        # encode with DINOv2
         | 
| 180 | 
            +
                    imgs = imgs.reshape(B*N, _, H, W)
         | 
| 181 | 
            +
                    hidden = self.encoder(imgs, is_training=True)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    if isinstance(hidden, dict):
         | 
| 184 | 
            +
                        hidden = hidden["x_norm_patchtokens"]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    hidden, pos = self.decode(hidden, N, H, W)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    point_hidden = self.point_decoder(hidden, xpos=pos)
         | 
| 189 | 
            +
                    conf_hidden = self.conf_decoder(hidden, xpos=pos)
         | 
| 190 | 
            +
                    camera_hidden = self.camera_decoder(hidden, xpos=pos)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    with torch.amp.autocast(device_type='cuda', enabled=False):
         | 
| 193 | 
            +
                        # local points
         | 
| 194 | 
            +
                        point_hidden = point_hidden.float()
         | 
| 195 | 
            +
                        ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
         | 
| 196 | 
            +
                        xy, z = ret.split([2, 1], dim=-1)
         | 
| 197 | 
            +
                        z = torch.exp(z)
         | 
| 198 | 
            +
                        local_points = torch.cat([xy * z, z], dim=-1)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                        # confidence
         | 
| 201 | 
            +
                        conf_hidden = conf_hidden.float()
         | 
| 202 | 
            +
                        conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                        # camera
         | 
| 205 | 
            +
                        camera_hidden = camera_hidden.float()
         | 
| 206 | 
            +
                        camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                        # unproject local points using camera poses
         | 
| 209 | 
            +
                        points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    return dict(
         | 
| 212 | 
            +
                        points=points,
         | 
| 213 | 
            +
                        local_points=local_points,
         | 
| 214 | 
            +
                        conf=conf,
         | 
| 215 | 
            +
                        camera_poses=camera_poses,
         | 
| 216 | 
            +
                    )
         | 
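A minimal sketch of driving the model end-to-end, assuming a CUDA device and that the flash-attention-backed blocks used above are available; the weights here are random, not a released checkpoint (PyTorchModelHubMixin also provides Pi3.from_pretrained(<hub id>) for loading one):

import torch

model = Pi3().to('cuda').eval()

# imgs: (B, N, 3, H, W) in [0, 1]; H and W must be multiples of the 14-pixel patch size
imgs = torch.rand(1, 4, 3, 224, 224, device='cuda')

with torch.no_grad():
    out = model(imgs)

print(out['points'].shape)        # (1, 4, 224, 224, 3)  points in the global frame
print(out['local_points'].shape)  # (1, 4, 224, 224, 3)  points in each camera frame
print(out['conf'].shape)          # (1, 4, 224, 224, 1)  per-pixel confidence
print(out['camera_poses'].shape)  # (1, 4, 4, 4)         camera-to-world 4x4 poses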
    	
        pi3/utils/basic.py
    ADDED
    
    | @@ -0,0 +1,223 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import os.path as osp
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import cv2
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from torchvision import transforms
         | 
| 8 | 
            +
            from plyfile import PlyData, PlyElement
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
         | 
| 12 | 
            +
                """
         | 
| 13 | 
            +
                Loads images from a directory or video, resizes them to a uniform size,
         | 
| 14 | 
            +
                then converts and stacks them into a single [N, 3, H, W] PyTorch tensor.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
                sources = [] 
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                # --- 1. Load image paths or video frames ---
         | 
| 19 | 
            +
                if osp.isdir(path):
         | 
| 20 | 
            +
                    print(f"Loading images from directory: {path}")
         | 
| 21 | 
            +
                    filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))])
         | 
| 22 | 
            +
                    for i in range(0, len(filenames), interval):
         | 
| 23 | 
            +
                        img_path = osp.join(path, filenames[i])
         | 
| 24 | 
            +
                        try:
         | 
| 25 | 
            +
                            sources.append(Image.open(img_path).convert('RGB'))
         | 
| 26 | 
            +
                        except Exception as e:
         | 
| 27 | 
            +
                            print(f"Could not load image {filenames[i]}: {e}")
         | 
| 28 | 
            +
                elif path.lower().endswith('.mp4'):
         | 
| 29 | 
            +
                    print(f"Loading frames from video: {path}")
         | 
| 30 | 
            +
                    cap = cv2.VideoCapture(path)
         | 
| 31 | 
            +
                    if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}")
         | 
| 32 | 
            +
                    frame_idx = 0
         | 
| 33 | 
            +
                    while True:
         | 
| 34 | 
            +
                        ret, frame = cap.read()
         | 
| 35 | 
            +
                        if not ret: break
         | 
| 36 | 
            +
                        if frame_idx % interval == 0:
         | 
| 37 | 
            +
                            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         | 
| 38 | 
            +
                            sources.append(Image.fromarray(rgb_frame))
         | 
| 39 | 
            +
                        frame_idx += 1
         | 
| 40 | 
            +
                    cap.release()
         | 
| 41 | 
            +
                else:
         | 
| 42 | 
            +
                    raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                if not sources:
         | 
| 45 | 
            +
                    print("No images found or loaded.")
         | 
| 46 | 
            +
                    return torch.empty(0)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                print(f"Found {len(sources)} images/frames. Processing...")
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # --- 2. Determine a uniform target size for all images based on the first image ---
         | 
| 51 | 
            +
                # This is necessary to ensure all tensors have the same dimensions for stacking.
         | 
| 52 | 
            +
                first_img = sources[0]
         | 
| 53 | 
            +
                W_orig, H_orig = first_img.size
         | 
| 54 | 
            +
                scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1
         | 
| 55 | 
            +
                W_target, H_target = W_orig * scale, H_orig * scale
         | 
| 56 | 
            +
                k, m = round(W_target / 14), round(H_target / 14)
         | 
| 57 | 
            +
                while (k * 14) * (m * 14) > PIXEL_LIMIT:
         | 
| 58 | 
            +
                    if k / m > W_target / H_target: k -= 1
         | 
| 59 | 
            +
                    else: m -= 1
         | 
| 60 | 
            +
                TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14
         | 
| 61 | 
            +
                print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})")
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                # --- 3. Resize images and convert them to tensors in the [0, 1] range ---
         | 
| 64 | 
            +
                tensor_list = []
         | 
| 65 | 
            +
                # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1]
         | 
| 66 | 
            +
                to_tensor_transform = transforms.ToTensor()
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
                for img_pil in sources:
         | 
| 69 | 
            +
                    try:
         | 
| 70 | 
            +
                        # Resize to the uniform target size
         | 
| 71 | 
            +
                        resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS)
         | 
| 72 | 
            +
                        # Convert to tensor
         | 
| 73 | 
            +
                        img_tensor = to_tensor_transform(resized_img)
         | 
| 74 | 
            +
                        tensor_list.append(img_tensor)
         | 
| 75 | 
            +
                    except Exception as e:
         | 
| 76 | 
            +
                        print(f"Error processing an image: {e}")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                if not tensor_list:
         | 
| 79 | 
            +
                    print("No images were successfully processed.")
         | 
| 80 | 
            +
                    return torch.empty(0)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor ---
         | 
| 83 | 
            +
                return torch.stack(tensor_list, dim=0)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def tensor_to_pil(tensor):
         | 
| 87 | 
            +
                """
         | 
| 88 | 
            +
                Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension 
         | 
| 89 | 
            +
                (if it has size 3) to the last axis before converting.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                Args:
         | 
| 92 | 
            +
                    tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W].
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                Returns:
         | 
| 95 | 
            +
                    PIL.Image: The converted PIL image.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                if torch.is_tensor(tensor):
         | 
| 98 | 
            +
                    array = tensor.detach().cpu().numpy()
         | 
| 99 | 
            +
                else:
         | 
| 100 | 
            +
                    array = tensor
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                return array_to_pil(array)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            def array_to_pil(array):
         | 
| 106 | 
            +
                """
         | 
| 107 | 
            +
                Converts a NumPy array to a PIL image. Automatically:
         | 
| 108 | 
            +
                    - Squeezes dimensions of size 1.
         | 
| 109 | 
            +
                    - Moves the channel dimension (if it has size 3) to the last axis.
         | 
| 110 | 
            +
                
         | 
| 111 | 
            +
                Args:
         | 
| 112 | 
            +
                    array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W].
         | 
| 113 | 
            +
                
         | 
| 114 | 
            +
                Returns:
         | 
| 115 | 
            +
                    PIL.Image: The converted PIL image.
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
                # Remove singleton dimensions
         | 
| 118 | 
            +
                array = np.squeeze(array)
         | 
| 119 | 
            +
                
         | 
| 120 | 
            +
                # Ensure the array has the channel dimension as the last axis
         | 
| 121 | 
            +
                if array.ndim == 3 and array.shape[0] == 3:  # If the channel is the first axis
         | 
| 122 | 
            +
                    array = np.transpose(array, (1, 2, 0))  # Move channel to the last axis
         | 
| 123 | 
            +
                
         | 
| 124 | 
            +
                # Handle single-channel grayscale images
         | 
| 125 | 
            +
                if array.ndim == 2:  # [H, W]
         | 
| 126 | 
            +
                    return Image.fromarray((array * 255).astype(np.uint8), mode="L")
         | 
| 127 | 
            +
                elif array.ndim == 3 and array.shape[2] == 3:  # [H, W, C] with 3 channels
         | 
| 128 | 
            +
                    return Image.fromarray((array * 255).astype(np.uint8), mode="RGB")
         | 
| 129 | 
            +
                else:
         | 
| 130 | 
            +
                    raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            def rotate_target_dim_to_last_axis(x, target_dim=3):
         | 
| 134 | 
            +
                shape = x.shape
         | 
| 135 | 
            +
                axis_to_move = -1
         | 
| 136 | 
            +
                # Iterate backwards to find the first occurrence from the end 
         | 
| 137 | 
            +
                # (which corresponds to the last dimension of size 3 in the original order).
         | 
| 138 | 
            +
                for i in range(len(shape) - 1, -1, -1):
         | 
| 139 | 
            +
                    if shape[i] == target_dim:
         | 
| 140 | 
            +
                        axis_to_move = i
         | 
| 141 | 
            +
                        break
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                # 2. If the axis is found and it's not already in the last position, move it.
         | 
| 144 | 
            +
                if axis_to_move != -1 and axis_to_move != len(shape) - 1:
         | 
| 145 | 
            +
                    # Create the new dimension order.
         | 
| 146 | 
            +
                    dims_order = list(range(len(shape)))
         | 
| 147 | 
            +
                    dims_order.pop(axis_to_move)
         | 
| 148 | 
            +
                    dims_order.append(axis_to_move)
         | 
| 149 | 
            +
                    
         | 
| 150 | 
            +
        # Reorder the axes with a NumPy-style transpose.
         | 
| 151 | 
            +
                    ret = x.transpose(*dims_order)
         | 
| 152 | 
            +
                else:
         | 
| 153 | 
            +
                    ret = x
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                return ret
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def write_ply(
         | 
| 159 | 
            +
                xyz,
         | 
| 160 | 
            +
                rgb=None,
         | 
| 161 | 
            +
                path='output.ply',
         | 
| 162 | 
            +
            ) -> None:
         | 
| 163 | 
            +
                if torch.is_tensor(xyz):
         | 
| 164 | 
            +
                    xyz = xyz.detach().cpu().numpy()
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                if torch.is_tensor(rgb):
         | 
| 167 | 
            +
                    rgb = rgb.detach().cpu().numpy()
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                if rgb is not None and rgb.max() > 1:
         | 
| 170 | 
            +
                    rgb = rgb / 255.
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                xyz = rotate_target_dim_to_last_axis(xyz, 3)
         | 
| 173 | 
            +
                xyz = xyz.reshape(-1, 3)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                if rgb is not None:
         | 
| 176 | 
            +
                    rgb = rotate_target_dim_to_last_axis(rgb, 3)
         | 
| 177 | 
            +
                    rgb = rgb.reshape(-1, 3)
         | 
| 178 | 
            +
                
         | 
| 179 | 
            +
                if rgb is None:
         | 
| 180 | 
            +
                    min_coord = np.min(xyz, axis=0)
         | 
| 181 | 
            +
                    max_coord = np.max(xyz, axis=0)
         | 
| 182 | 
            +
                    normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8)
         | 
| 183 | 
            +
                    
         | 
| 184 | 
            +
                    hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2]
         | 
| 185 | 
            +
                    hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    c = hsv[:,2:] * hsv[:,1:2]
         | 
| 188 | 
            +
                    x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 ))
         | 
| 189 | 
            +
                    m = hsv[:,2:] - c
         | 
| 190 | 
            +
                    
         | 
| 191 | 
            +
                    rgb = np.zeros_like(hsv)
         | 
| 192 | 
            +
                    cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1)
         | 
| 193 | 
            +
                    rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])])
         | 
| 194 | 
            +
                    cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2)
         | 
| 195 | 
            +
                    rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])])
         | 
| 196 | 
            +
                    cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3)
         | 
| 197 | 
            +
                    rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]])
         | 
| 198 | 
            +
                    cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4)
         | 
| 199 | 
            +
                    rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]])
         | 
| 200 | 
            +
                    cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5)
         | 
| 201 | 
            +
                    rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]])
         | 
| 202 | 
            +
                    cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6)
         | 
| 203 | 
            +
                    rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]])
         | 
| 204 | 
            +
                    rgb = (rgb + m)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                dtype = [
         | 
| 207 | 
            +
                    ("x", "f4"),
         | 
| 208 | 
            +
                    ("y", "f4"),
         | 
| 209 | 
            +
                    ("z", "f4"),
         | 
| 210 | 
            +
                    ("nx", "f4"),
         | 
| 211 | 
            +
                    ("ny", "f4"),
         | 
| 212 | 
            +
                    ("nz", "f4"),
         | 
| 213 | 
            +
                    ("red", "u1"),
         | 
| 214 | 
            +
                    ("green", "u1"),
         | 
| 215 | 
            +
                    ("blue", "u1"),
         | 
| 216 | 
            +
                ]
         | 
| 217 | 
            +
                normals = np.zeros_like(xyz)
         | 
| 218 | 
            +
                elements = np.empty(xyz.shape[0], dtype=dtype)
         | 
| 219 | 
            +
                attributes = np.concatenate((xyz, normals, rgb * 255), axis=1)
         | 
| 220 | 
            +
                elements[:] = list(map(tuple, attributes))
         | 
| 221 | 
            +
                vertex_element = PlyElement.describe(elements, "vertex")
         | 
| 222 | 
            +
                ply_data = PlyData([vertex_element])
         | 
| 223 | 
            +
                ply_data.write(path)
         | 
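The helpers above are meant to be chained with the model's forward pass. A hedged sketch (the example video ships with this commit; `model` is a Pi3 instance as in the sketch after pi3/models/pi3.py):

import torch

imgs = load_images_as_tensor('examples/skating.mp4', interval=10)   # (N, 3, H, W) in [0, 1]
imgs = imgs[None].to('cuda')                                        # add the batch dim -> (1, N, 3, H, W)

with torch.no_grad():
    out = model(imgs)

# colour each predicted point with its source pixel and dump a point cloud
write_ply(out['points'][0], rgb=imgs[0].permute(0, 2, 3, 1), path='output.ply')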
    	
        pi3/utils/debug.py
    ADDED
    
    | @@ -0,0 +1,63 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import debugpy
         | 
| 4 | 
            +
            import socket
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def update_vscode_launch_file(host: str, port: int):
         | 
| 8 | 
            +
                """Update the .vscode/launch.json file with the new host and port."""
         | 
| 9 | 
            +
                launch_file_path = ".vscode/launch.json"
         | 
| 10 | 
            +
                # Desired configuration
         | 
| 11 | 
            +
                new_config = {
         | 
| 12 | 
            +
                    "version": "0.2.0",
         | 
| 13 | 
            +
                    "configurations": [
         | 
| 14 | 
            +
                        {
         | 
| 15 | 
            +
                            "name": "bash_debug",
         | 
| 16 | 
            +
                            "type": "debugpy",
         | 
| 17 | 
            +
                            "request": "attach",
         | 
| 18 | 
            +
                            "connect": {
         | 
| 19 | 
            +
                                "host": host,
         | 
| 20 | 
            +
                                "port": port
         | 
| 21 | 
            +
                            },
         | 
| 22 | 
            +
                            "justMyCode": False
         | 
| 23 | 
            +
                        },
         | 
| 24 | 
            +
                    ]
         | 
| 25 | 
            +
                }
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                # Ensure the .vscode directory exists
         | 
| 28 | 
            +
                if not os.path.exists(".vscode"):
         | 
| 29 | 
            +
                    os.makedirs(".vscode")
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                # Write the updated configuration to launch.json
         | 
| 32 | 
            +
                with open(launch_file_path, "w") as f:
         | 
| 33 | 
            +
                    json.dump(new_config, f, indent=4)
         | 
| 34 | 
            +
                print(f"Updated {launch_file_path} with host: {host} and port: {port}")
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def is_port_in_use(host, port):
         | 
| 37 | 
            +
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         | 
| 38 | 
            +
                    return s.connect_ex((host, port)) == 0
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)):
         | 
| 41 | 
            +
                if is_main_process:
         | 
| 42 | 
            +
                    host = os.environ['SLURM_NODELIST'].split(',')[0]
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    for _ in range(max_retries):
         | 
| 45 | 
            +
                        port = random.randint(*port_range)
         | 
| 46 | 
            +
                        try:
         | 
| 47 | 
            +
                            if is_port_in_use(host, port):
         | 
| 48 | 
            +
                                print(f"Port {port} is already in use, trying another...")
         | 
| 49 | 
            +
                                continue
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                # Update .vscode/launch.json with the chosen host and port
         | 
| 52 | 
            +
                            update_vscode_launch_file(host, port)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                            print("master_addr = ", host)
         | 
| 55 | 
            +
                            debugpy.listen((host, port))
         | 
| 56 | 
            +
                            print(f"Waiting for debugger attach at port {port}...", flush=True)
         | 
| 57 | 
            +
                            debugpy.wait_for_client()
         | 
| 58 | 
            +
                            print("Debugger attached", flush=True)
         | 
| 59 | 
            +
                            return
         | 
| 60 | 
            +
                        except Exception as e:
         | 
| 61 | 
            +
                            print(f"Failed to bind to port {port}: {e}")
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    raise RuntimeError("Could not find a free port for debugpy after several attempts.")
         | 
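setup_debug reads SLURM_NODELIST, so it is only meaningful inside a SLURM job. A minimal, hedged way to call it from the rank-0 process (the RANK/DEBUG environment variables are illustrative, not part of this commit):

import os

if int(os.environ.get('RANK', '0')) == 0 and os.environ.get('DEBUG', '0') == '1':
    setup_debug(is_main_process=True)   # writes .vscode/launch.json, then blocks until VS Code attaches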
    	
        pi3/utils/geometry.py
    ADDED
    
    | @@ -0,0 +1,375 @@ | |
import numpy as np
import torch
import torch.nn.functional as F

def se3_inverse(T):
    """
    Computes the inverse of a batch of SE(3) matrices.
    T: Tensor of shape (B, 4, 4)
    """
    if len(T.shape) == 2:
        T = T[None]
        unseq_flag = True
    else:
        unseq_flag = False

    if torch.is_tensor(T):
        R = T[:, :3, :3]
        t = T[:, :3, 3].unsqueeze(-1)
        R_inv = R.transpose(-2, -1)
        t_inv = -torch.matmul(R_inv, t)
        T_inv = torch.cat([
            torch.cat([R_inv, t_inv], dim=-1),
            torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(T.shape[0], 1, 1)
        ], dim=1)
    else:
        R = T[:, :3, :3]
        t = T[:, :3, 3, np.newaxis]

        R_inv = np.swapaxes(R, -2, -1)
        t_inv = -R_inv @ t

        bottom_row = np.zeros((T.shape[0], 1, 4), dtype=T.dtype)
        bottom_row[:, :, 3] = 1

        top_part = np.concatenate([R_inv, t_inv], axis=-1)
        T_inv = np.concatenate([top_part, bottom_row], axis=1)

    if unseq_flag:
        T_inv = T_inv[0]
    return T_inv

def get_pixel(H, W):
    # get 2D pixels (u, v) for image_a in cam_a pixel space
    u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
    # u_a = np.flip(u_a, axis=1)
    # v_a = np.flip(v_a, axis=0)
    pixels_a = np.stack([
        u_a.flatten() + 0.5,
        v_a.flatten() + 0.5,
        np.ones_like(u_a.flatten())
    ], axis=0)

    return pixels_a

def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
    if z_far > 0:
        valid_mask = valid_mask & (depthmap < z_far)

    X_world = X_cam  # default
    if camera_pose is not None:
        # R_cam2world = np.float32(camera_params["R_cam2world"])
        # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]

        # Express in absolute coordinates (invalid depth values)
        X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]

    return X_world, valid_mask


def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: there are no skew terms
    # assert camera_intrinsics[0, 1] == 0.0
    # assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Mask for valid coordinates
    valid_mask = (depthmap > 0.0)
    # Invalid any depth > 80m
    valid_mask = valid_mask
    return X_cam, valid_mask

def homogenize_points(
    points,
):
    """Convert batched points (xyz) to (xyz1)."""
    return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)


def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode='bilinear', relative_depth_error_threshold=0.05, H=None, W=None):

    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        x1_n = torch.meshgrid(
            *[
                torch.linspace(
                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
                )
                for n in (B, H, W)
            ],
            indexing='ij'
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
        return x2, prob

@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask=False, return_relative_depth_error=False, depth_interpolation_mode="bilinear", relative_depth_error_threshold=0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt
    Also check covisibility and depth consistency.
    Depth is consistent if relative error < 0.2 (hard-coded).
    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
    Returns:
        calculable_mask (torch.Tensor): [N, L]
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
    """
    (
        n,
        h,
        w,
    ) = depth0.shape
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                  smooth_mask=smooth_mask,
                  return_relative_depth_error=return_relative_depth_error,
                  depth_interpolation_mode="bilinear",
                  relative_depth_error_threshold=relative_depth_error_threshold)
        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                  smooth_mask=smooth_mask,
                  return_relative_depth_error=return_relative_depth_error,
                  depth_interpolation_mode="nearest-exact",
                  relative_depth_error_threshold=relative_depth_error_threshold)
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False)[
        :, 0, :, 0
    ]
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Sample depth, get calculable_mask on depth != 0
    # nonzero_mask = kpts0_depth != 0
    # Sample depth, get calculable_mask on depth > 0
    nonzero_mask = kpts0_depth > 0

    # Unproject
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
    kpts0_cam = kpts0_n

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
    # w_kpts0[~covisible_mask, :] = -5 # xd

    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
    )[:, 0, :, 0]

    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    else:
        return valid_mask, w_kpts0


def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """ Invert a torch or numpy matrix
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')

def opencv_camera_to_plucker(poses, K, H, W):
    device = poses.device
    B = poses.shape[0]

    pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1)         # (3, H, W)
    pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
    ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel)

    ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)

    ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
    plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
    plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1)

    return plucker_ray


def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance
        rtol (float): relative tolerance

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    shape = depth.shape
    depth = depth.reshape(-1, 1, *shape[-2:])
    if mask is not None:
        mask = mask.reshape(-1, 1, *shape[-2:])

    if mask is None:
        diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2))
    else:
        diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2))

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= (diff / depth).nan_to_num_() > rtol
    edge = edge.reshape(*shape)
    return edge
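For orientation, a small sketch of how these helpers compose. The intrinsics, pose, and flat depth map below are made-up values for illustration, and the `pi3.utils.geometry` import path is assumed; none of this is part of the commit:

# Illustrative round trip (assumed values, not repository code): unproject a depth
# map to world coordinates with a cam2world pose, then map back with se3_inverse.
import numpy as np
from pi3.utils.geometry import depthmap_to_absolute_camera_coordinates, se3_inverse

H, W = 4, 6
depth = np.ones((H, W), dtype=np.float32)               # flat plane at z = 1
K = np.array([[100.0, 0.0, W / 2],
              [0.0, 100.0, H / 2],
              [0.0, 0.0, 1.0]], dtype=np.float32)       # pinhole intrinsics, no skew
pose = np.eye(4, dtype=np.float32)
pose[:3, 3] = [0.0, 0.0, 2.0]                           # camera sits 2 m along world +z

pts_world, valid = depthmap_to_absolute_camera_coordinates(depth, K, pose)
print(pts_world.shape, valid.all())                     # (4, 6, 3) True

# world -> camera with the inverted pose recovers the original depth
world2cam = se3_inverse(pose)
pts_h = np.concatenate([pts_world, np.ones((H, W, 1), np.float32)], axis=-1)
pts_cam = np.einsum("ij, hwj -> hwi", world2cam, pts_h)[..., :3]
assert np.allclose(pts_cam[..., 2], depth, atol=1e-5)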
    	
requirements.txt
ADDED
@@ -0,0 +1,13 @@
torch==2.5.1
torchvision==0.20.1
numpy==1.26.4
pillow
opencv-python
plyfile
huggingface_hub

# below for gradio
gradio
trimesh
matplotlib
scipy
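The first three entries are pinned to exact versions while the Gradio-side dependencies float. A small startup check like the following, an illustration rather than part of the commit, can confirm the pinned trio resolves as expected:

# Optional sanity check (illustrative): warn if the environment drifts from the pins above.
from importlib.metadata import version

expected = {"torch": "2.5.1", "torchvision": "0.20.1", "numpy": "1.26.4"}
for pkg, want in expected.items():
    got = version(pkg)
    if got != want:
        print(f"Warning: {pkg} {got} installed, but requirements.txt pins {want}")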