diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc913611dd82e4b649b8f24b162d2cd5b24fe807
--- /dev/null
+++ b/app.py
@@ -0,0 +1,157 @@
+import sys
+sys.path.append("flash3d")
+
+from omegaconf import OmegaConf
+import gradio as gr
+import spaces
+import torch
+import torchvision.transforms as TT
+import torchvision.transforms.functional as TTF
+from huggingface_hub import hf_hub_download
+
+from networks.gaussian_predictor import GaussianPredictor
+from util.vis3d import save_ply
+
+
+def main():
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+
+    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", 
+                                     filename="config_re10k_v1.yaml")
+    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", 
+                                 filename="model_re10k_v1.pth")
+
+    cfg = OmegaConf.load(model_cfg_path)
+    model = GaussianPredictor(cfg)
+    device = torch.device("cuda:0")
+    model.to(device)
+    model.load_model(model_path)
+
+    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
+    to_tensor = TT.ToTensor()
+
+    def check_input_image(input_image):
+        if input_image is None:
+            raise gr.Error("No image uploaded!")
+
+    def preprocess(image):
+        image = TTF.resize(
+            image, (cfg.dataset.height, cfg.dataset.width), 
+            interpolation=TT.InterpolationMode.BICUBIC
+        )
+        image = pad_border_fn(image)
+        return image
+
+    @spaces.GPU()
+    def reconstruct_and_export(image):
+        """
+        Passes image through model, outputs reconstruction in form of a dict of tensors.
+        """
+        image = to_tensor(image).to(device).unsqueeze(0)
+        inputs = {
+            ("color_aug", 0, 0): image,
+        }
+
+        outputs = model(inputs)
+
+        # export reconstruction to ply
+        save_ply(outputs, ply_out_path, num_gauss=2)
+
+        return ply_out_path
+    
+    ply_out_path = f'./mesh.ply'
+
+    css = """
+        h1 {
+            text-align: center;
+            display:block;
+        }
+        """
+
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # Flash3D
+            """
+            )
+        with gr.Row(variant="panel"):
+            with gr.Column(scale=1):
+                with gr.Row():
+                    input_image = gr.Image(
+                        label="Input Image",
+                        image_mode="RGBA",
+                        sources="upload",
+                        type="pil",
+                        elem_id="content_image",
+                    )
+                with gr.Row():
+                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+                with gr.Row(variant="panel"): 
+                    gr.Examples(
+                        examples=[
+                            './demo_examples/bedroom_01.png',
+                            './demo_examples/kitti_02.png',
+                            './demo_examples/kitti_03.png',
+                            './demo_examples/re10k_04.jpg',
+                            './demo_examples/re10k_05.jpg',
+                            './demo_examples/re10k_06.jpg',
+                        ],
+                        inputs=[input_image],
+                        cache_examples=False,
+                        label="Examples",
+                        examples_per_page=20,
+                    )
+
+                with gr.Row():
+                    processed_image = gr.Image(label="Processed Image", interactive=False)
+
+            with gr.Column(scale=2):
+                with gr.Row():
+                    with gr.Tab("Reconstruction"):
+                        output_model = gr.Model3D(
+                            height=512,
+                            label="Output Model",
+                            interactive=False
+                        )
+
+        # gr.Markdown(
+        # """
+        #     ## Comments:
+        #     1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
+        #     2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximations and artefacts might show.
+        #     3. Known limitations include:
+        #     - a black dot appearing on the model from some viewpoints
+        #     - see-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes
+        #     - back of objects are blurry: this is a model limiation due to it being deterministic
+        #     4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
+        #     ## How does it work?
+        #     Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image, 
+        #     in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours and locations.
+        #     The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
+        #     The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
+        #     The rendering is also very fast, due to using Gaussian Splatting.
+        #     Combined, this results in very cheap training and high-quality results.
+        #     For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
+        #     """
+        # )
+
+        submit.click(fn=check_input_image, inputs=[input_image]).success(
+            fn=preprocess,
+            inputs=[input_image],
+            outputs=[processed_image],
+        ).success(
+            fn=reconstruct_and_export,
+            inputs=[processed_image],
+            outputs=[output_model],
+        )
+
+    demo.queue(max_size=1)
+    demo.launch(share=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/demo_examples/bedroom_01.png b/demo_examples/bedroom_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..5e1e7f4940a28cde585f8be4e337e50e71e3a0ac
Binary files /dev/null and b/demo_examples/bedroom_01.png differ
diff --git a/demo_examples/kitti_02.png b/demo_examples/kitti_02.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4bcf249f1280a2df1553c4ae6bb8d123a0500c9
Binary files /dev/null and b/demo_examples/kitti_02.png differ
diff --git a/demo_examples/kitti_03.png b/demo_examples/kitti_03.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e25037fff2e08c5f28e7dff1df50e72b0ede003
Binary files /dev/null and b/demo_examples/kitti_03.png differ
diff --git a/demo_examples/re10k_04.jpg b/demo_examples/re10k_04.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3e87f75b23d49f1c1ad15c3fd5bf7dde20a33a26
Binary files /dev/null and b/demo_examples/re10k_04.jpg differ
diff --git a/demo_examples/re10k_05.jpg b/demo_examples/re10k_05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e76a0c1d5febff7836fb7806ef93e052c11be3f9
Binary files /dev/null and b/demo_examples/re10k_05.jpg differ
diff --git a/demo_examples/re10k_06.jpg b/demo_examples/re10k_06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f8532c7f5b6b7b1e50c6da24e7abdbe093ee7e1
Binary files /dev/null and b/demo_examples/re10k_06.jpg differ
diff --git a/flash3d/networks/depth_decoder.py b/flash3d/networks/depth_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..059d8a8714f61bb1773aee60f786788d605f4eec
--- /dev/null
+++ b/flash3d/networks/depth_decoder.py
@@ -0,0 +1,81 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+from networks.layers import upsample, ConvBlock, Conv3x3
+
+from einops import rearrange
+
+
+class DepthDecoder(nn.Module):
+    def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
+        super(DepthDecoder, self).__init__()
+
+        self.cfg = cfg
+        depth_num = cfg.model.gaussians_per_pixel - 1 if "unidepth" in cfg.model.name else cfg.model.gaussians_per_pixel
+        self.num_output_channels = num_output_channels * depth_num
+        self.use_skips = use_skips
+        self.upsample_mode = 'nearest'
+        self.scales = cfg.model.scales
+
+        self.num_ch_enc = num_ch_enc
+        self.num_ch_dec = np.array([16, 32, 64, 128, 256])
+
+        # decoder
+        self.convs = OrderedDict()
+        for i in range(4, -1, -1):
+            # upconv_0
+            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+
+            # upconv_1
+            num_ch_in = self.num_ch_dec[i]
+            if self.use_skips and i > 0:
+                num_ch_in += self.num_ch_enc[i - 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+
+        for s in self.scales:
+            out = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
+            self.convs[("dispconv", s)] = out
+            nn.init.xavier_uniform_(out.conv.weight, cfg.model.depth_scale)
+            nn.init.constant_(out.conv.bias, cfg.model.depth_bias)
+
+        self.decoder = nn.ModuleList(list(self.convs.values()))
+        if cfg.model.depth_type in ["disp", "disp_inc"]:
+            self.activate = nn.Sigmoid()
+        elif cfg.model.depth_type == "depth":
+            self.activate = nn.Softplus()
+        elif cfg.model.depth_type == "depth_inc":
+            self.activate = torch.exp
+
+    def forward(self, input_features):
+        outputs = {}
+        x = input_features[-1]
+        for i in range(4, -1, -1):
+            x = self.convs[("upconv", i, 0)](x)
+            x = [upsample(x)]
+            if self.use_skips and i > 0:
+                x += [input_features[i - 1]]
+            x = torch.cat(x, 1)
+            x = self.convs[("upconv", i, 1)](x)
+            if i in self.scales:
+                depth_num = self.cfg.model.gaussians_per_pixel - 1 if "unidepth" in self.cfg.model.name else self.cfg.model.gaussians_per_pixel
+                if self.cfg.model.depth_type == "depth_inc":
+                    outputs[("depth", i)] = rearrange(self.activate(torch.clamp(self.convs[("dispconv", i)](x), min=-10.0, max=6.0)),
+                                                 'b (n c) ...-> (b n) c ...', n = depth_num)
+                elif self.cfg.model.depth_type in ["disp", "disp_inc"]:
+                    outputs[("disp", i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
+                                                 'b (n c) ...-> (b n) c ...', n = depth_num)
+                else:
+                    outputs[(self.cfg.model.depth_type, i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
+                                                 'b (n c) ...-> (b n) c ...', n = depth_num)
+        return outputs
diff --git a/flash3d/networks/gaussian_decoder.py b/flash3d/networks/gaussian_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..38be79f6175a11be06463822cc24a1862212aa5f
--- /dev/null
+++ b/flash3d/networks/gaussian_decoder.py
@@ -0,0 +1,196 @@
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+def upsample(x):
+    """Upsample input tensor by a factor of 2
+    """
+    return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class Conv3x3(nn.Module):
+    """Layer to pad and convolve input
+    """
+    def __init__(self, in_channels, out_channels, use_refl=True):
+        super(Conv3x3, self).__init__()
+
+        if use_refl:
+            self.pad = nn.ReflectionPad2d(1)
+        else:
+            self.pad = nn.ZeroPad2d(1)
+        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
+
+    def forward(self, x):
+        out = self.pad(x)
+        out = self.conv(out)
+        return out
+
+
+class ConvBlock(nn.Module):
+    """Layer to perform a convolution followed by ELU
+    """
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv = Conv3x3(in_channels, out_channels)
+        self.nonlin = nn.ELU(inplace=True)
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.nonlin(out)
+        return out
+
+
+def get_splits_and_inits(cfg):
+    split_dimensions = []
+    scale_inits = []
+    bias_inits = []
+
+    for g_idx in range(cfg.model.gaussians_per_pixel):
+        if cfg.model.predict_offset:
+            split_dimensions += [3]
+            scale_inits += [cfg.model.xyz_scale]
+            bias_inits += [cfg.model.xyz_bias]
+
+        split_dimensions += [1, 3, 4, 3]
+        scale_inits += [cfg.model.opacity_scale, 
+                        cfg.model.scale_scale,
+                        1.0,
+                        5.0]
+        bias_inits += [cfg.model.opacity_bias,
+                        np.log(cfg.model.scale_bias),
+                        0.0,
+                        0.0]
+
+        if cfg.model.max_sh_degree != 0:
+            sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1
+            sh_num_rgb = sh_num * 3
+            split_dimensions.append(sh_num_rgb)
+            scale_inits.append(cfg.model.sh_scale)
+            bias_inits.append(0.0)
+        if not cfg.model.one_gauss_decoder:
+            break
+
+    return split_dimensions, scale_inits, bias_inits, 
+
+
+class GaussianDecoder(nn.Module):
+    def __init__(self, cfg, num_ch_enc, use_skips=True):
+        super(GaussianDecoder, self).__init__()
+
+        self.cfg = cfg
+        self.use_skips = use_skips
+        self.upsample_mode = 'nearest'
+
+        self.num_ch_enc = num_ch_enc
+        self.num_ch_dec = np.array(cfg.model.num_ch_dec)
+
+        split_dimensions, scale, bias = get_splits_and_inits(cfg)
+
+        # [offset], opacity, scaling, rotation, feat_dc
+        assert not cfg.model.unified_decoder
+
+        self.split_dimensions = split_dimensions
+
+        self.num_output_channels = sum(self.split_dimensions)
+
+        # decoder
+        self.convs = OrderedDict()
+        for i in range(4, -1, -1):
+            # upconv_0
+            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+
+            # upconv_1
+            num_ch_in = self.num_ch_dec[i]
+            if self.use_skips and i > 0:
+                num_ch_in += self.num_ch_enc[i - 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+
+        self.out = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 1)
+
+        out_channels = self.split_dimensions
+        start_channels = 0
+        for out_channel, b, s in zip(out_channels, bias, scale):
+            nn.init.xavier_uniform_(
+                self.out.weight[start_channels:start_channels+out_channel,
+                                :, :, :], s)
+            nn.init.constant_(
+                self.out.bias[start_channels:start_channels+out_channel], b)
+            start_channels += out_channel
+
+        self.decoder = nn.ModuleList(list(self.convs.values()))
+
+        self.scaling_activation = torch.exp
+        self.opacity_activation = torch.sigmoid
+        self.rotation_activation = torch.nn.functional.normalize
+        self.scaling_lambda = cfg.model.scale_lambda
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input_features):
+        self.outputs = {}
+
+        # decoder
+        x = input_features[-1]
+        for i in range(4, -1, -1):
+            x = self.convs[("upconv", i, 0)](x)
+            x = [upsample(x)]
+            if self.use_skips and i > 0:
+                x += [input_features[i - 1]]
+            x = torch.cat(x, 1)
+            x = self.convs[("upconv", i, 1)](x)
+
+        x = self.out(x)
+
+        split_network_outputs = x.split(self.split_dimensions, dim=1)
+
+        offset_list = []
+        opacity_list = []
+        scaling_list = []
+        rotation_list = []
+        feat_dc_list = []
+        feat_rest_list = []
+
+        assert not self.cfg.model.unified_decoder
+
+        for i in range(self.cfg.model.gaussians_per_pixel):
+            assert self.cfg.model.max_sh_degree > 0
+            if self.cfg.model.predict_offset:
+                offset_s, opacity_s, scaling_s, \
+                    rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*6:(i+1)*6]
+                offset_list.append(offset_s[:, None, ...])
+            else:
+                opacity_s, scaling_s, rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*5:(i+1)*5]
+            opacity_list.append(opacity_s[:, None, ...])
+            scaling_list.append(scaling_s[:, None, ...])
+            rotation_list.append(rotation_s[:, None, ...])
+            feat_dc_list.append(feat_dc_s[:, None, ...])
+            feat_rest_list.append(features_rest_s[:, None, ...])
+            if not self.cfg.model.one_gauss_decoder:
+                break
+
+        # squeezing will remove dimension if there is only one gaussian per pixel
+        opacity = torch.cat(opacity_list, dim=1).squeeze(1)
+        scaling = torch.cat(scaling_list, dim=1).squeeze(1)
+        rotation = torch.cat(rotation_list, dim=1).squeeze(1)
+        feat_dc = torch.cat(feat_dc_list, dim=1).squeeze(1)
+        features_rest = torch.cat(feat_rest_list, dim=1).squeeze(1)
+
+        out = {
+            ("gauss_opacity", 0): self.opacity_activation(opacity),
+            ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
+            ("gauss_rotation", 0): self.rotation_activation(rotation),
+            ("gauss_features_dc", 0): feat_dc,
+            ("gauss_features_rest", 0): features_rest
+        }
+
+        if self.cfg.model.predict_offset:
+            offset = torch.cat(offset_list, dim=1).squeeze(1)
+            out[("gauss_offset", 0)] = offset
+        return out
+
diff --git a/flash3d/networks/gaussian_predictor.py b/flash3d/networks/gaussian_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0291462a0b793e22c69d8c73192ae3cff7b7ef3
--- /dev/null
+++ b/flash3d/networks/gaussian_predictor.py
@@ -0,0 +1,293 @@
+from pathlib import Path
+import logging
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from networks.layers import BackprojectDepth, disp_to_depth
+from networks.resnet_encoder import ResnetEncoder
+from networks.depth_decoder import DepthDecoder
+from networks.gaussian_decoder import GaussianDecoder
+
+
+def default_param_group(model):
+    return [{'params': model.parameters()}]
+
+
+def to_device(inputs, device):
+    for key, ipt in inputs.items():
+        if isinstance(ipt, torch.Tensor):
+            inputs[key] = ipt.to(device)
+    return inputs
+
+
+class GaussianPredictor(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+
+        # checking height and width are multiples of 32
+        # assert cfg.dataset.width % 32 == 0, "'width' must be a multiple of 32"
+
+        models = {}
+        self.parameters_to_train = []
+
+        self.num_scales = len(cfg.model.scales)
+
+        assert cfg.model.frame_ids[0] == 0, "frame_ids must start with 0"
+
+        if cfg.model.use_stereo:
+            cfg.model.frame_ids.append("s")
+
+        model_name = cfg.model.name
+        if model_name == "resnet":
+            models["encoder"] = ResnetEncoder(
+                cfg.model.num_layers,
+                cfg.model.weights_init == "pretrained",
+                cfg.model.resnet_bn_order
+            )
+            self.parameters_to_train += default_param_group(models["encoder"])
+            if not cfg.model.unified_decoder:
+                models["depth"] = DepthDecoder(
+                    cfg, models["encoder"].num_ch_enc)
+                self.parameters_to_train += default_param_group(models["depth"])
+            if cfg.model.gaussian_rendering:
+                for i in range(cfg.model.gaussians_per_pixel):
+                    gauss_decoder = GaussianDecoder(
+                        cfg, models["encoder"].num_ch_enc,
+                    )
+                    self.parameters_to_train += default_param_group(gauss_decoder)
+                    models["gauss_decoder_"+str(i)] = gauss_decoder
+        elif model_name == "unidepth":
+            from networks.unidepth import UniDepthSplatter
+            models["unidepth"] = UniDepthSplatter(cfg)
+            self.parameters_to_train += models["unidepth"].get_parameter_groups()
+        elif model_name in ["unidepth_unprojector_vit", "unidepth_unprojector_cnvnxtl"]:
+            from networks.unidepth import UniDepthUnprojector
+            models["unidepth"] = UniDepthUnprojector(cfg)
+            self.parameters_to_train += models["unidepth"].get_parameter_groups()
+        elif model_name in ["unidepth_extension_vit", "unidepth_extension_cnvnxtl"]:
+            from networks.unidepth_extension import UniDepthExtended
+            models["unidepth_extended"] = UniDepthExtended(cfg)
+            self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
+
+        self.models = nn.ModuleDict(models)
+
+        backproject_depth = {}
+        H = cfg.dataset.height
+        W = cfg.dataset.width
+        for scale in cfg.model.scales:
+            h = H // (2 ** scale)
+            w = W // (2 ** scale)
+            if cfg.model.shift_rays_half_pixel == "zero":
+                shift_rays_half_pixel = 0
+            elif cfg.model.shift_rays_half_pixel == "forward":
+                shift_rays_half_pixel = 0.5
+            elif cfg.model.shift_rays_half_pixel == "backward":
+                shift_rays_half_pixel = -0.5
+            else:
+                raise NotImplementedError
+            backproject_depth[str(scale)] = BackprojectDepth(
+                cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel, 
+                # backprojection can be different if padding was used
+                h + 2 * self.cfg.dataset.pad_border_aug, 
+                w + 2 * self.cfg.dataset.pad_border_aug,
+                shift_rays_half_pixel=shift_rays_half_pixel
+            )
+        self.backproject_depth = nn.ModuleDict(backproject_depth)
+
+    def set_train(self):
+        """Convert all models to training mode
+        """
+        for m in self.models.values():
+            m.train()
+        self._is_train = True
+
+    def set_eval(self):
+        """Convert all models to testing/evaluation mode
+        """
+        for m in self.models.values():
+            m.eval()
+        self._is_train = False
+    
+    def is_train(self):
+        return self._is_train
+    
+    def forward(self, inputs):
+        cfg = self.cfg
+        B = cfg.optimiser.batch_size
+
+        if cfg.model.name == "resnet":
+            do_flip = self.is_train() and \
+                    cfg.train.lazy_flip_augmentation and \
+                    (torch.rand(1) > .5).item()
+            # Otherwise, we only feed the image with frame_id 0 through the depth encoder
+            input_img = inputs["color_aug", 0, 0]
+            if do_flip:
+                input_img = torch.flip(input_img, dims=(-1, ))
+            features = self.models["encoder"](input_img)
+            if not cfg.model.unified_decoder:
+                outputs = self.models["depth"](features)
+            else:
+                outputs = dict()
+            
+            if self.cfg.model.gaussian_rendering:
+                # gauss_feats = self.models["gauss_encoder"](inputs["color_aug", 0, 0])
+                input_f_id = 0
+                gauss_feats = features
+                gauss_outs = dict()
+                for i in range(self.cfg.model.gaussians_per_pixel):
+                    outs = self.models["gauss_decoder_"+str(i)](gauss_feats)
+                    for key, v in outs.items():
+                        gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
+                for key, v in gauss_outs.items():
+                    gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
+                outputs |= gauss_outs
+                outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
+            else:
+                for scale in cfg.model.scales:
+                    outputs[("disp", 0, scale)] = outputs[("disp", scale)]
+            
+            # unflip all outputs
+            if do_flip:
+                for k, v in outputs.items():
+                    outputs[k] = torch.flip(v, dims=(-1, ))
+        elif "unidepth" in cfg.model.name:
+            if cfg.model.name in ["unidepth", 
+                                  "unidepth_unprojector_vit", 
+                                  "unidepth_unprojector_cnvnxtl"]:
+                outputs = self.models["unidepth"](inputs)
+            elif cfg.model.name in ["unidepth_extension_vit",
+                                    "unidepth_extension_cnvnxtl"]:
+                outputs = self.models["unidepth_extended"](inputs)
+
+            input_f_id = 0
+            outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
+
+        input_f_id = 0
+        scale = 0
+        if not ("depth", input_f_id, scale) in outputs:
+            disp = outputs[("disp", input_f_id, scale)]
+            _, depth = disp_to_depth(disp, cfg.model.min_depth, cfg.model.max_depth)
+            outputs[("depth", input_f_id, scale)] = depth
+
+        self.compute_gauss_means(inputs, outputs)
+
+        return outputs
+
+    def target_tensor_image_dims(self, inputs):
+        B, _, H, W = inputs["color", 0, 0].shape
+        return B, H, W
+
+    def compute_gauss_means(self, inputs, outputs):
+        cfg = self.cfg
+        input_f_id = 0
+        scale = 0
+        depth = outputs[("depth", input_f_id, scale)]
+        B, _, H, W = depth.shape
+        if ("inv_K_src", scale) in inputs:
+            inv_K = inputs[("inv_K_src", scale)]
+        else:
+            inv_K = outputs[("inv_K_src", input_f_id, scale)]
+        if self.cfg.model.gaussians_per_pixel > 1:
+            inv_K = rearrange(inv_K[:,None,...].
+                              repeat(1, self.cfg.model.gaussians_per_pixel, 1, 1),
+                              'b n ... -> (b n) ...')
+        xyz = self.backproject_depth[str(scale)](
+            depth, inv_K
+        )
+        inputs[("inv_K_src", scale)] = inv_K
+        if cfg.model.predict_offset:
+            offset = outputs[("gauss_offset", input_f_id, scale)]
+            if cfg.model.scaled_offset:
+                offset = offset * depth.detach()
+            offset = offset.view(B, 3, -1)
+            zeros = torch.zeros(B, 1, H * W, device=depth.device)
+            offset = torch.cat([offset, zeros], 1)
+            xyz = xyz + offset # [B, 4, W*H]
+        outputs[("gauss_means", input_f_id, scale)] = xyz
+
+    def checkpoint_dir(self):
+        return Path("checkpoints")
+    
+    def save_model(self, optimizer, step, ema=None):
+        """Save model weights to disk
+        """
+        save_folder = self.checkpoint_dir()
+        save_folder.mkdir(exist_ok=True, parents=True)
+
+        save_path = save_folder / f"model_{step:07}.pth"
+        logging.info(f"saving checkpoint to {str(save_path)}")
+
+        model = ema.ema_model if ema is not None else self
+        save_dict = {
+            "model": model.state_dict(),
+            "version": "1.0",
+            "optimiser": optimizer.state_dict(),
+            "step": step
+        }
+        torch.save(save_dict, save_path)
+
+        num_ckpts = self.cfg.optimiser.num_keep_ckpts
+        ckpts = sorted(list(save_folder.glob("model_*.pth")), reverse=True)
+        if len(ckpts) > num_ckpts:
+            for ckpt in ckpts[num_ckpts:]:
+                ckpt.unlink()
+
+    def load_model(self, weights_path, optimizer=None):
+        """Load model(s) from disk
+        """
+        weights_path = Path(weights_path)
+
+        # determine if it is an old or new saving format
+        if weights_path.is_dir() and weights_path.joinpath("encoder.pth").exists():
+            self.load_model_old(weights_path, optimizer)
+            return
+
+        logging.info(f"Loading weights from {weights_path}...")
+        state_dict = torch.load(weights_path)
+        if "version" in state_dict and state_dict["version"] == "1.0":
+            new_dict = {}
+            for k, v in state_dict["model"].items():
+                if "backproject_depth" in k:
+                    new_dict[k] = self.state_dict()[k].clone()
+                else:
+                    new_dict[k] = v.clone()
+            # for k, v in state_dict["model"].items():
+            #     if "backproject_depth" in k and ("pix_coords" in k or "ones" in k):
+            #         # model has these parameters set as a function of batch size
+            #         # when batch size changes in eval this results in a loading error
+            #         state_dict["model"][k] = v[:1, ...]
+            self.load_state_dict(new_dict, strict=False)
+        else:
+            # TODO remove loading according to the old format
+            for name in self.cfg.train.models_to_load:
+                if name not in self.models:
+                    continue
+                self.models[name].load_state_dict(state_dict[name])
+
+        # loading adam state
+        if optimizer is not None:
+            optimizer.load_state_dict(state_dict["optimiser"])
+            self.step = state_dict["step"]
+
+    def load_model_old(self, weights_folder, optimizer=None):
+        for n in self.cfg.train.models_to_load:
+            print(f"Loading {n} weights...")
+            path = weights_folder / f"{n}.pth"
+            if n not in self.models:
+                continue
+            model_dict = self.models[n].state_dict()
+            pretrained_dict = torch.load(path)
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+            model_dict.update(pretrained_dict)
+            self.models[n].load_state_dict(model_dict)
+
+        # loading adam state
+        optimizer_load_path = weights_folder / "adam.pth"
+        if optimizer is not None and optimizer_load_path.is_file():
+            print("Loading Adam weights")
+            optimizer_state = torch.load(optimizer_load_path)
+            optimizer.load_state_dict(optimizer_state["adam"])
+            self.step = optimizer_state["step"]
diff --git a/flash3d/networks/layers.py b/flash3d/networks/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad9b89b6b4e9f58e30fcde6ad9f81bdf7bf07caa
--- /dev/null
+++ b/flash3d/networks/layers.py
@@ -0,0 +1,295 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def disp_to_depth(disp, min_depth, max_depth):
+    """Convert network's sigmoid output into depth prediction
+    The formula for this conversion is given in the 'additional considerations'
+    section of the paper.
+    """
+    min_disp = 1 / max_depth
+    max_disp = 1 / min_depth
+    scaled_disp = min_disp + (max_disp - min_disp) * disp
+    depth = 1 / scaled_disp
+    return scaled_disp, depth
+
+
+def transformation_from_parameters(axisangle, translation, invert=False):
+    """Convert the network's (axisangle, translation) output into a 4x4 matrix
+    """
+    R = rot_from_axisangle(axisangle)
+    t = translation.clone()
+
+    if invert:
+        R = R.transpose(1, 2)
+        t *= -1
+
+    T = get_translation_matrix(t)
+
+    if invert:
+        M = torch.matmul(R, T)
+    else:
+        M = torch.matmul(T, R)
+
+    return M
+
+
+def get_translation_matrix(translation_vector):
+    """Convert a translation vector into a 4x4 transformation matrix
+    """
+    T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
+
+    t = translation_vector.contiguous().view(-1, 3, 1)
+
+    T[:, 0, 0] = 1
+    T[:, 1, 1] = 1
+    T[:, 2, 2] = 1
+    T[:, 3, 3] = 1
+    T[:, :3, 3, None] = t
+
+    return T
+
+
+def rot_from_axisangle(vec):
+    """Convert an axisangle rotation into a 4x4 transformation matrix
+    (adapted from https://github.com/Wallacoloo/printipi)
+    Input 'vec' has to be Bx1x3
+    """
+    angle = torch.norm(vec, 2, 2, True)
+    axis = vec / (angle + 1e-7)
+
+    ca = torch.cos(angle)
+    sa = torch.sin(angle)
+    C = 1 - ca
+
+    x = axis[..., 0].unsqueeze(1)
+    y = axis[..., 1].unsqueeze(1)
+    z = axis[..., 2].unsqueeze(1)
+
+    xs = x * sa
+    ys = y * sa
+    zs = z * sa
+    xC = x * C
+    yC = y * C
+    zC = z * C
+    xyC = x * yC
+    yzC = y * zC
+    zxC = z * xC
+
+    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
+
+    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
+    rot[:, 0, 1] = torch.squeeze(xyC - zs)
+    rot[:, 0, 2] = torch.squeeze(zxC + ys)
+    rot[:, 1, 0] = torch.squeeze(xyC + zs)
+    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
+    rot[:, 1, 2] = torch.squeeze(yzC - xs)
+    rot[:, 2, 0] = torch.squeeze(zxC - ys)
+    rot[:, 2, 1] = torch.squeeze(yzC + xs)
+    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
+    rot[:, 3, 3] = 1
+
+    return rot
+
+
+class ConvBlock(nn.Module):
+    """Layer to perform a convolution followed by ELU
+    """
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv = Conv3x3(in_channels, out_channels)
+        self.nonlin = nn.ELU(inplace=True)
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.nonlin(out)
+        return out
+
+
+class Conv3x3(nn.Module):
+    """Layer to pad and convolve input
+    """
+    def __init__(self, in_channels, out_channels, use_refl=True):
+        super(Conv3x3, self).__init__()
+
+        if use_refl:
+            self.pad = nn.ReflectionPad2d(1)
+        else:
+            self.pad = nn.ZeroPad2d(1)
+        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
+
+    def forward(self, x):
+        out = self.pad(x)
+        out = self.conv(out)
+        return out
+
+
+class BackprojectDepth(nn.Module):
+    """Layer to transform a depth image into a point cloud
+    """
+    def __init__(self, batch_size, height, width, shift_rays_half_pixel=0):
+        super(BackprojectDepth, self).__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+
+        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
+        id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
+        id_coords = torch.from_numpy(id_coords)
+
+        ones = torch.ones(self.batch_size, 1, self.height * self.width)
+
+        pix_coords = torch.unsqueeze(torch.stack(
+            [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0)
+        pix_coords = pix_coords.repeat(batch_size, 1, 1)
+        pix_coords = torch.cat([pix_coords + shift_rays_half_pixel, 
+                                ones], 1)
+        self.register_buffer("pix_coords", pix_coords)
+        self.register_buffer("id_coords", id_coords)
+        self.register_buffer("ones", ones)
+        # self.pix_coords = pix_coords
+        # self.ones = ones
+
+    def forward(self, depth, inv_K):
+        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords.to(depth.device))
+        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
+        cam_points = torch.cat([cam_points, self.ones.to(depth.device)], 1)
+
+        return cam_points
+
+
+class Project3D(nn.Module):
+    """Layer which projects 3D points into a camera with intrinsics K and at position T
+    """
+    def __init__(self, batch_size, height, width, eps=1e-7):
+        super(Project3D, self).__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.eps = eps
+
+    def forward(self, points, K, T=None):
+        if T is None:
+            P = K
+        else:
+            P = torch.matmul(K, T)
+        P = P[:, :3, :]
+
+        cam_points = torch.matmul(P, points)
+
+        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
+        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
+        pix_coords = pix_coords.permute(0, 2, 3, 1)
+        pix_coords[..., 0] /= self.width - 1
+        pix_coords[..., 1] /= self.height - 1
+        pix_coords = (pix_coords - 0.5) * 2
+        return pix_coords
+
+
+class Project3DSimple(nn.Module):
+    """Layer which projects 3D points into a camera with intrinsics K and at position T
+    """
+    def __init__(self, batch_size, height, width, eps=1e-7):
+        super(Project3DSimple, self).__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.eps = eps
+
+    def forward(self, points, K):
+        K = K[:, :3, :]
+
+        cam_points = torch.matmul(K, points)
+
+        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
+        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
+        pix_coords = pix_coords.permute(0, 2, 3, 1)
+        return pix_coords
+
+def upsample(x):
+    """Upsample input tensor by a factor of 2
+    """
+    return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+def get_smooth_loss(disp, img):
+    """Computes the smoothness loss for a disparity image
+    The color image is used for edge-aware smoothness
+    """
+    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
+    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
+
+    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
+    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
+
+    grad_disp_x *= torch.exp(-grad_img_x)
+    grad_disp_y *= torch.exp(-grad_img_y)
+
+    return grad_disp_x.mean() + grad_disp_y.mean()
+
+
+class SSIM(nn.Module):
+    """Layer to compute the SSIM loss between a pair of images
+    """
+    def __init__(self):
+        super(SSIM, self).__init__()
+        self.mu_x_pool   = nn.AvgPool2d(3, 1)
+        self.mu_y_pool   = nn.AvgPool2d(3, 1)
+        self.sig_x_pool  = nn.AvgPool2d(3, 1)
+        self.sig_y_pool  = nn.AvgPool2d(3, 1)
+        self.sig_xy_pool = nn.AvgPool2d(3, 1)
+
+        self.refl = nn.ReflectionPad2d(1)
+
+        self.C1 = 0.01 ** 2
+        self.C2 = 0.03 ** 2
+
+    def forward(self, x, y):
+        x = self.refl(x)
+        y = self.refl(y)
+
+        mu_x = self.mu_x_pool(x)
+        mu_y = self.mu_y_pool(y)
+
+        sigma_x  = self.sig_x_pool(x ** 2) - mu_x ** 2
+        sigma_y  = self.sig_y_pool(y ** 2) - mu_y ** 2
+        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
+
+        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
+        SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
+
+        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
+
+
+def compute_depth_errors(gt, pred):
+    """Computation of error metrics between predicted and ground truth depths
+    """
+    thresh = torch.max((gt / pred), (pred / gt))
+    a1 = (thresh < 1.25     ).float().mean()
+    a2 = (thresh < 1.25 ** 2).float().mean()
+    a3 = (thresh < 1.25 ** 3).float().mean()
+
+    rmse = (gt - pred) ** 2
+    rmse = torch.sqrt(rmse.mean())
+
+    rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
+    rmse_log = torch.sqrt(rmse_log.mean())
+
+    abs_rel = torch.mean(torch.abs(gt - pred) / gt)
+
+    sq_rel = torch.mean((gt - pred) ** 2 / gt)
+
+    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
diff --git a/flash3d/networks/resnet_encoder.py b/flash3d/networks/resnet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad6056ebc32f59f7fbbdbc844bfb017b0160360
--- /dev/null
+++ b/flash3d/networks/resnet_encoder.py
@@ -0,0 +1,115 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+RESNETS = {18: (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
+           50: (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2)}
+
+
+class ResNetMultiImageInput(models.ResNet):
+    """Constructs a resnet model with varying number of input images.
+    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+    """
+    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
+        super(ResNetMultiImageInput, self).__init__(block, layers)
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(
+            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
+    """Constructs a ResNet model.
+    Args:
+        num_layers (int): Number of resnet layers. Must be 18 or 50
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        num_input_images (int): Number of frames stacked as input
+    """
+    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
+    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
+    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
+    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
+    model, weigths = RESNETS[num_layers]
+
+    if pretrained:
+        loaded = torch.hub.load_state_dict_from_url(weigths.url)
+        loaded['conv1.weight'] = torch.cat(
+            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
+        model.load_state_dict(loaded)
+    return model
+
+
+class ResnetEncoder(nn.Module):
+    """Pytorch module for a resnet encoder
+    """
+    def __init__(self, num_layers, pretrained, bn_order, num_input_images=1):
+        super(ResnetEncoder, self).__init__()
+
+        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+        self.bn_order = bn_order
+
+        if num_layers not in RESNETS:
+            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+        if num_input_images > 1:
+            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
+        else:
+            model, weights = RESNETS[num_layers]
+            self.encoder = model(weights=weights)
+
+        if num_layers > 34:
+            self.num_ch_enc[1:] *= 4
+
+    def forward(self, input_image):
+        encoder = self.encoder
+        features = []
+        x = (input_image - 0.45) / 0.225
+        x = encoder.conv1(x)
+
+        if self.bn_order == "pre_bn":
+            # Concatenating pre-norm features allows us to 
+            # keep the scale and shift of RGB colours 
+            # and recover them at output
+            features.append(x)
+            x = encoder.bn1(x)
+            x = encoder.relu(x)
+            features.append(encoder.layer1(encoder.maxpool(x)))
+        elif self.bn_order == "monodepth":
+            # Batchnorm gets rid of constants due to colour shift
+            # will make the network not able to recover absolute colour shift
+            # of the input image
+            # used in old models
+            x = encoder.bn1(x)
+            x = encoder.relu(x)
+            features.append(x)
+            features.append(encoder.layer1(encoder.maxpool(x)))
+        else:
+            assert False
+
+        features.append(encoder.layer2(features[-1]))
+        features.append(encoder.layer3(features[-1]))
+        features.append(encoder.layer4(features[-1]))
+
+        return features
diff --git a/flash3d/networks/unidepth.py b/flash3d/networks/unidepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..346e382643867ef00bcdd410a5b58d60e9bdb574
--- /dev/null
+++ b/flash3d/networks/unidepth.py
@@ -0,0 +1,577 @@
+import json
+from pathlib import Path
+from typing import List, Tuple
+from math import ceil
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from einops import rearrange
+
+from unidepth.models.unidepthv1 import UniDepthV1
+from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
+from unidepth.utils.geometric import (
+    generate_rays,
+    spherical_zbuffer_to_euclidean,
+    flat_interpolate,
+)
+from unidepth.layers import (
+    MLP,
+    AttentionBlock,
+    NystromBlock,
+    PositionEmbeddingSine,
+    ConvUpsample,
+)
+from unidepth.utils.sht import rsh_cart_8
+
+from networks.gaussian_decoder import get_splits_and_inits
+
+
+# inference helpers
+def _paddings(image_shape, network_shape):
+    cur_h, cur_w = image_shape
+    h, w = network_shape
+    pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
+    pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
+    return pad_left, pad_right, pad_top, pad_bottom
+
+
+def _shapes(image_shape, network_shape):
+    h, w = image_shape
+    input_ratio = w / h
+    output_ratio = network_shape[1] / network_shape[0]
+    if output_ratio > input_ratio:
+        ratio = network_shape[0] / h
+    elif output_ratio <= input_ratio:
+        ratio = network_shape[1] / w
+    return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
+
+
+def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
+    (pad_left, pad_right, pad_top, pad_bottom) = pads
+    rgbs = F.interpolate(
+        rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
+    )
+    rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
+    if intrinsics is not None:
+        intrinsics = intrinsics.clone()
+        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
+        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
+        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
+        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
+        return rgbs, intrinsics
+    return rgbs, None
+
+
+def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
+    
+    (pad_left, pad_right, pad_top, pad_bottom) = pads
+    # pred mean, trim paddings, and upsample to input dim
+    predictions = sum(
+        [
+            F.interpolate(
+                x,
+                size=shapes,
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+            for x in predictions
+        ]
+    ) / len(predictions)
+
+    shapes = predictions.shape[2:]
+    predictions = predictions[
+        ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
+    ]
+
+    predictions = F.interpolate(
+        predictions,
+        size=original_shapes,
+        mode="bilinear",
+        align_corners=False,
+        antialias=True,
+    )
+
+    if intrinsics is not None:
+        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
+        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
+        intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
+        intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
+
+    return predictions, intrinsics
+
+
+def scale_intrinsics_xy(intrinsics, x_ratio, y_ratio):
+    intrinsics = intrinsics.clone()
+    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * x_ratio
+    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * y_ratio
+    intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * x_ratio
+    intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * y_ratio
+    return intrinsics
+
+
+def scale_intrinsics(intrinsics, ratio):
+    intrinsics = intrinsics.clone()
+    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
+    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
+    intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
+    intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
+    return intrinsics
+
+
+def unidepthv1_forward(model, rgbs, intrinsics, skip_camera,
+                       return_raw_preds=False):
+    B, _, H, W = rgbs.shape
+
+    rgbs = TF.normalize(
+        rgbs,
+        mean=IMAGENET_DATASET_MEAN,
+        std=IMAGENET_DATASET_STD,
+    )
+
+    (h, w), ratio = _shapes((H, W), model.image_shape)
+    pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), model.image_shape)
+    rgbs, gt_intrinsics = _preprocess(
+        rgbs,
+        intrinsics,
+        (h, w),
+        (pad_left, pad_right, pad_top, pad_bottom),
+        ratio,
+        model.image_shape,
+    )
+    
+    encoder_outputs, cls_tokens = model.pixel_encoder(rgbs)
+    if "dino" in model.pixel_encoder.__class__.__name__.lower():
+        encoder_outputs = [
+            (x + y.unsqueeze(1)).contiguous()
+            for x, y in zip(encoder_outputs, cls_tokens)
+        ]
+    
+    # get data for decoder and adapt to given camera
+    inputs = {}
+    inputs["encoder_outputs"] = encoder_outputs
+    inputs["cls_tokens"] = cls_tokens
+    inputs["image"] = rgbs
+    if gt_intrinsics is not None:
+        rays, angles = generate_rays(
+            gt_intrinsics, model.image_shape, noisy=False
+        )
+        inputs["rays"] = rays
+        inputs["angles"] = angles
+        inputs["K"] = gt_intrinsics
+        model.pixel_decoder.test_fixed_camera = True
+        model.pixel_decoder.skip_camera = skip_camera
+
+    # decode all
+    pred_intrinsics, predictions, features, rays = model.pixel_decoder(inputs, {})
+
+    pads = (pad_left, pad_right, pad_top, pad_bottom)
+
+    # undo the reshaping and get original image size (slow)
+    predictions, pred_intrinsics = _postprocess(
+        predictions,
+        pred_intrinsics,
+        model.image_shape,
+        pads,
+        ratio,
+        (H, W),
+    )
+
+    if return_raw_preds:
+        return inputs, predictions
+
+    # final 3D points backprojection
+    intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
+    angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
+    angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
+    points_3d = torch.cat((angles, predictions), dim=1)
+    points_3d = spherical_zbuffer_to_euclidean(
+        points_3d.permute(0, 2, 3, 1)
+    ).permute(0, 3, 1, 2)
+
+    # output data
+    outputs = {
+        "intrinsics": intrinsics,
+        "points": points_3d,
+        "depth": predictions[:, -1:],
+        "depth_feats": features,
+        "rays": rays,
+        "padding": pads
+    }
+    model.pixel_decoder.test_fixed_camera = False
+    model.pixel_decoder.skip_camera = False
+    return inputs, outputs
+
+class UniDepthDepth(nn.Module):
+    def __init__(
+        self,
+        cfg,
+        return_raw_preds=False
+    ):
+        super().__init__()
+
+        self.cfg = cfg
+        self.return_raw_preds = return_raw_preds
+
+        if "cnvnxtl" in cfg.model.name:
+            self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
+        elif "vit" in cfg.model.name:
+            self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
+
+        self.skip_camera = True
+
+    def get_depth(self, img, intrinsics):
+        depth_inputs, outputs = unidepthv1_forward(
+            self.depth_prediction_model, 
+            img, 
+            intrinsics, 
+            self.skip_camera,
+            return_raw_preds=self.return_raw_preds)
+        return outputs
+
+    def forward(self, inputs):
+        input_img = inputs["color_aug", 0, 0]
+        # here we need the intrinsics of the source image to condition on
+        # the depth prediction. needs to account for padding 
+        if ("K_src", 0) in inputs:
+            intrinsics = inputs[("K_src", 0)]
+        else:
+            intrinsics = None
+
+        depth_inputs, outputs = unidepthv1_forward(
+            self.depth_prediction_model, 
+            input_img, 
+            intrinsics, 
+            self.skip_camera,
+            return_raw_preds=self.return_raw_preds)
+
+        return depth_inputs, outputs
+
+class UniDepthUnprojector(nn.Module):
+    def __init__(
+        self,
+        cfg
+    ):
+        super().__init__()
+
+        self.cfg = cfg
+
+        if cfg.model.name == "unidepth_unprojector_cnvnxtl":
+            model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
+        elif cfg.model.name == "unidepth_unprojector_vit":
+            model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
+        self.unidepth = model
+
+        self.skip_camera = True
+
+        self.register_buffer("gauss_opacity", torch.ones(1, 1, 1).float())
+        self.register_buffer("gauss_scaling", torch.ones(3, 1, 1).float())
+        self.register_buffer("gauss_rotation", torch.ones(4, 1, 1).float() * 0.5)
+        self.register_buffer("gauss_features_rest", torch.zeros(9, 1, 1).float())
+        self.register_buffer("gauss_offset", torch.zeros(3, 1, 1).float())
+
+        self.all_params = nn.ParameterDict({
+                           "opacity_scaling": nn.Parameter(torch.tensor(cfg.model.opacity_bias).float()),
+                           "scale_scaling": nn.Parameter(torch.tensor(cfg.model.scale_bias).float()),
+                           "colour_scaling": nn.Parameter(torch.tensor(self.cfg.model.colour_scale).float())})
+
+        
+        self.scaling_activation = torch.exp
+        self.opacity_activation = torch.sigmoid
+        self.relu = nn.ReLU()
+
+    def get_parameter_groups(self):
+        # tune scalars for size, opacity and colour modulation
+        return [{'params': self.all_params.parameters()}]
+
+    def forward(self, inputs):
+        model = self.unidepth
+        input_img = inputs["color_aug", 0, 0]
+        # here we need the intrinsics of the source image to condition on
+        # the depth prediction. needs to account for padding 
+        intrinsics = inputs[("K_src", 0)]
+        b, c, h, w = inputs["color_aug", 0, 0].shape
+
+        with torch.no_grad():
+            _, depth_outs = unidepthv1_forward(model, input_img, intrinsics, self.skip_camera)
+
+        outs = {}
+
+        outs[("gauss_opacity", 0)] = self.gauss_opacity.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
+            * self.opacity_activation(self.all_params["opacity_scaling"])
+        if not self.cfg.model.scale_with_depth:
+            outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
+                * self.scaling_activation(self.all_params["scale_scaling"])
+        else:
+            outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
+                * self.scaling_activation(self.all_params["scale_scaling"]) * depth_outs["depth"] / 10.0
+        outs[("gauss_rotation", 0)] = self.gauss_rotation.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
+        outs[("gauss_offset", 0)] = self.gauss_offset.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
+        outs[("gauss_features_rest", 0)] = self.gauss_features_rest.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
+        # rendering adds 0.5 to go from rendered colours to output
+        outs[("gauss_features_dc", 0)] = (input_img - 0.5)* self.relu(self.all_params["colour_scaling"])
+
+        outs[("depth", 0)] = depth_outs["depth"]
+
+        return outs
+
+class UniDepthSplatter(nn.Module):
+    def __init__(
+        self,
+        cfg
+    ):
+        super().__init__()
+
+        self.cfg = cfg
+
+        config_path = Path("/work/eldar/src/UniDepth")
+        with open(config_path / "configs/config_v1_cnvnxtl.json") as f:
+            config = json.load(f)
+        self.unidepth = UniDepthDepth(self.cfg)
+
+        hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
+        expansion = config["model"]["expansion"]
+        depth = config["model"]["pixel_decoder"]["depths"]
+        num_heads = config["model"]["num_heads"]
+        dropout = config["model"]["pixel_decoder"]["dropout"]
+        layer_scale = 1.0
+        self.splat_decoder = GaussSplatHead(
+            cfg,
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            expansion=expansion,
+            depths=depth,
+            camera_dim=81,
+            dropout=dropout,
+            layer_scale=layer_scale,
+        )
+
+        self.skip_camera = True
+
+    def get_parameter_groups(self):
+        base_lr = self.cfg.optimiser.learning_rate
+        return [
+            {'params': self.unidepth.parameters(), "lr": base_lr * 0.05},
+            {'params': self.splat_decoder.parameters()}
+        ]
+
+    def forward(self, inputs):
+        gauss_head = self.splat_decoder
+
+        depth_inputs, depth_outs = self.unidepth(inputs)
+        depth_feats = depth_outs["depth_feats"]
+        rays = depth_outs["rays"]
+        padding = depth_outs["padding"]
+        
+        B, _, H, W = depth_inputs["image"].shape
+
+        # TODO remove hardcoded shapes
+        common_shape = (28, 38)
+        gauss_head.set_shapes(common_shape)
+        gauss_head.set_original_shapes((H, W))
+
+        depth_feats = rearrange(depth_feats, "b c h w -> b (h w) c")
+        outs = gauss_head(
+            latents_16=depth_feats,
+            rays_hr=rays,
+        )
+        for k, v in outs.items():
+            pred, _ = _postprocess([v], None, self.unidepth.depth_prediction_model.image_shape, 
+                                   padding, None, inputs["color_aug", 0, 0].shape[2:4])
+            outs[k] = pred
+        outs[("depth", 0)] = depth_outs["depth"]
+
+        return outs
+
+
+class GaussSplatHead(nn.Module):
+    def __init__(
+        self,
+        cfg,
+        hidden_dim: int,
+        num_heads: int = 8,
+        expansion: int = 4,
+        depths: int | list[int] = 4,
+        camera_dim: int = 256,
+        dropout: float = 0.0,
+        layer_scale: float = 1.0,
+    ) -> None:
+        super().__init__()
+
+        self.cfg = cfg
+
+        if isinstance(depths, int):
+            depths = [depths] * 3
+        assert len(depths) == 3
+
+        self.project_rays16 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
+        )
+        self.project_rays8 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
+        )
+        self.project_rays4 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
+        )
+
+        self.layers_8 = nn.ModuleList([])
+        self.layers_4 = nn.ModuleList([])
+        layers_16 = nn.ModuleList([])
+
+        self.up8 = ConvUpsample(
+            hidden_dim, expansion=expansion, layer_scale=layer_scale
+        )
+        self.up4 = ConvUpsample(
+            hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
+        )
+        self.up2 = ConvUpsample(
+            hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
+        )
+
+        split_dimensions, scale, bias = get_splits_and_inits(cfg)
+        start = 1
+        self.split_dimensions = split_dimensions[start:]
+        scale = scale[start:]
+        bias = bias[start:]
+
+        self.num_output_channels = sum(self.split_dimensions)
+
+        self.out2 = nn.Conv2d(hidden_dim // 8, self.num_output_channels, 3, padding=1)
+        # self.out4 = nn.Conv2d(hidden_dim // 4, self.num_output_channels, 3, padding=1)
+        # self.out8 = nn.Conv2d(hidden_dim // 2, self.num_output_channels, 3, padding=1)
+
+        start_channels = 0
+        for out_channel, b, s in zip(self.split_dimensions, bias, scale):
+            nn.init.xavier_uniform_(
+                self.out2.weight[start_channels:start_channels+out_channel,
+                                :, :, :], s)
+            nn.init.constant_(
+                self.out2.bias[start_channels:start_channels+out_channel], b)
+            start_channels += out_channel
+
+        for i, (blk_lst, depth) in enumerate(
+            zip([layers_16, self.layers_8, self.layers_4], depths)
+        ):
+            if i == 0:
+                continue
+            attn_cls = AttentionBlock if i == 0 else NystromBlock
+            for _ in range(depth):
+                blk_lst.append(
+                    attn_cls(
+                        hidden_dim // (2**i),
+                        num_heads=num_heads // (2**i),
+                        expansion=expansion,
+                        dropout=dropout,
+                        layer_scale=layer_scale,
+                    )
+                )
+
+        self.scaling_activation = torch.exp
+        self.opacity_activation = torch.sigmoid
+        self.rotation_activation = torch.nn.functional.normalize
+        self.scaling_lambda = cfg.model.scale_lambda
+        self.sigmoid = nn.Sigmoid()
+
+    def set_original_shapes(self, shapes: Tuple[int, int]):
+        self.original_shapes = shapes
+
+    def set_shapes(self, shapes: Tuple[int, int]):
+        self.shapes = shapes
+
+    def forward(
+        self, latents_16: torch.Tensor, rays_hr: torch.Tensor
+    ) -> torch.Tensor:
+        shapes = self.shapes
+
+        # camera_embedding
+        # torch.cuda.synchronize()
+        # start = time()
+        rays_embedding_16 = F.normalize(
+            flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
+        )
+        rays_embedding_8 = F.normalize(
+            flat_interpolate(
+                rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
+            ),
+            dim=-1,
+        )
+        rays_embedding_4 = F.normalize(
+            flat_interpolate(
+                rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
+            ),
+            dim=-1,
+        )
+        rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
+        rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
+        rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
+
+        # Block 16 - Out 8
+        latents_8 = self.up8(
+            rearrange(
+                latents_16 + rays_embedding_16,
+                "b (h w) c -> b c h w",
+                h=shapes[0],
+                w=shapes[1],
+            ).contiguous()
+        )
+        # out8 = self.out8(
+        #     rearrange(
+        #         latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
+        #     )
+        # )
+
+        # Block 8 - Out 4
+        for layer in self.layers_8:
+            latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
+        latents_4 = self.up4(
+            rearrange(
+                latents_8 + rays_embedding_8,
+                "b (h w) c -> b c h w",
+                h=shapes[0] * 2,
+                w=shapes[1] * 2,
+            ).contiguous()
+        )
+        # out4 = self.out4(
+        #     rearrange(
+        #         latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
+        #     )
+        # )
+
+        # Block 4 - Out 2
+        for layer in self.layers_4:
+            latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
+        latents_2 = self.up2(
+            rearrange(
+                latents_4 + rays_embedding_4,
+                "b (h w) c -> b c h w",
+                h=shapes[0] * 4,
+                w=shapes[1] * 4,
+            ).contiguous()
+        )
+        out2 = self.out2(
+            rearrange(
+                latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
+            )
+        )
+
+        split_network_outputs = out2.split(self.split_dimensions, dim=1)
+        last = 5
+        offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:last]
+
+        out = {
+            ("gauss_opacity", 0): self.opacity_activation(opacity),
+            ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
+            ("gauss_rotation", 0): self.rotation_activation(rotation),
+            ("gauss_features_dc", 0): feat_dc
+        }
+
+        if self.cfg.model.max_sh_degree > 0:
+            features_rest = split_network_outputs[last]
+            out[("gauss_features_rest", 0)] = features_rest
+
+        if self.cfg.model.predict_offset:
+            out[("gauss_offset", 0)] = offset
+
+        return out
+        # return out8, out4, out2, proj_latents_16
diff --git a/flash3d/networks/unidepth_extension.py b/flash3d/networks/unidepth_extension.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcb1cffb9aaff211f55bd589e36d3713bfe42200
--- /dev/null
+++ b/flash3d/networks/unidepth_extension.py
@@ -0,0 +1,205 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .unidepth import UniDepthDepth
+from unidepth.models import UniDepthV1
+from .resnet_encoder import ResnetEncoder
+from .gaussian_decoder import GaussianDecoder
+from .depth_decoder import DepthDecoder
+
+from networks.layers import disp_to_depth
+from networks.gaussian_decoder import get_splits_and_inits
+
+
+class UniDepthExtended(nn.Module):
+    def __init__(self,cfg):
+        super().__init__()
+
+        self.cfg = cfg
+
+        self.unidepth = UniDepthDepth(cfg)
+        # self.unidepth = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
+
+        self.parameters_to_train = []
+        if self.cfg.model.splat_branch == "resnet":
+            self.encoder = ResnetEncoder(cfg.model.num_layers,
+                                         cfg.model.weights_init == "pretrained",
+                                         cfg.model.resnet_bn_order
+                                        )
+            # change encoder to take depth as conditioning
+            if self.cfg.model.depth_cond:
+                self.encoder.encoder.conv1 = nn.Conv2d(
+                    4,
+                    self.encoder.encoder.conv1.out_channels,
+                    kernel_size = self.encoder.encoder.conv1.kernel_size,
+                    padding = self.encoder.encoder.conv1.padding,
+                    stride = self.encoder.encoder.conv1.stride
+                )
+            self.parameters_to_train += [{"params": self.encoder.parameters()}]
+
+            # use depth branch only for more gaussians
+            if cfg.model.gaussians_per_pixel > 1:
+                models ={}
+                models["depth"] = DepthDecoder(cfg, self.encoder.num_ch_enc)
+                self.parameters_to_train +=[{"params": models["depth"].parameters()}]
+                for i in range(cfg.model.gaussians_per_pixel):
+                    models["gauss_decoder_"+str(i)] = GaussianDecoder(cfg, self.encoder.num_ch_enc)
+                    self.parameters_to_train += [{"params": models["gauss_decoder_"+str(i)].parameters()}]
+                    if cfg.model.one_gauss_decoder:
+                        break
+                self.models = nn.ModuleDict(models)
+            else:
+                self.gauss_decoder = GaussianDecoder(cfg, self.encoder.num_ch_enc)
+                self.parameters_to_train += [{"params": self.gauss_decoder.parameters()}]
+        
+        elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
+            self.splat_branch = UniDepthDepth(cfg,
+                                              return_raw_preds=True)
+            # modify the head to output the channels for Gaussian parameters
+            self.init_ouput_head_splat_branch()
+            self.parameters_to_train +=[{"params": self.splat_branch.parameters()}]
+
+        self.scaling_activation = torch.exp
+        self.opacity_activation = torch.sigmoid
+        self.rotation_activation = torch.nn.functional.normalize
+
+    def init_ouput_head_splat_branch(self):
+        split_dimensions, scale, bias = get_splits_and_inits(self.cfg)
+        # the first dim in the output is for depth - we don't use that in this branch
+        self.split_dimensions = split_dimensions[1:]
+        scale = scale[1:]
+        bias = bias[1:]
+
+        self.num_output_channels = sum(self.split_dimensions)
+
+        self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2 = \
+            nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.in_channels, 
+                      self.num_output_channels,
+                kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.kernel_size,
+                padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.padding)
+
+        self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4 = \
+            nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.in_channels, 
+                      self.num_output_channels,
+                kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.kernel_size,
+                padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.padding)
+
+        self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8 = \
+            nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.in_channels, 
+                      self.num_output_channels,
+                kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.kernel_size,
+                padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.padding)
+
+        start_channels = 0
+        for out_channel, b, s in zip(split_dimensions, bias, scale):
+            nn.init.xavier_uniform_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.weight[start_channels:start_channels+out_channel,
+                                :, :, :], s)
+            nn.init.constant_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.bias[start_channels:start_channels+out_channel], b)
+            start_channels += out_channel
+        
+        start_channels = 0
+        for out_channel, b, s in zip(split_dimensions, bias, scale):
+            nn.init.xavier_uniform_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.weight[start_channels:start_channels+out_channel,
+                                :, :, :], s)
+            nn.init.constant_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.bias[start_channels:start_channels+out_channel], b)
+            start_channels += out_channel
+
+        start_channels = 0
+        for out_channel, b, s in zip(split_dimensions, bias, scale):
+            nn.init.xavier_uniform_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.weight[start_channels:start_channels+out_channel,
+                                :, :, :], s)
+            nn.init.constant_(
+                self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.bias[start_channels:start_channels+out_channel], b)
+            start_channels += out_channel
+
+    def get_parameter_groups(self):
+        # only the resnet encoder and gaussian parameter decoder are optimisable
+        return self.parameters_to_train
+
+    def forward(self, inputs):
+        if ('unidepth', 0, 0) in inputs.keys() and inputs[('unidepth', 0, 0)] is not None:
+            depth_outs = dict()
+            depth_outs["depth"] = inputs[('unidepth', 0, 0)]
+        else:
+            with torch.no_grad():
+                # if self.training and self.cfg.dataset.pad_border_aug > 0:
+                #     pad = self.cfg.dataset.pad_border_aug
+                #     input = inputs["color_aug", 0, 0][:,:,pad:-pad, pad:-pad]
+                #     intrincs = inputs[("K_tgt", 0)]
+                # else:
+                #     input = inputs["color_aug", 0, 0]
+                #     intrincs = inputs[("K_src", 0)]
+                _, depth_outs = self.unidepth(inputs)
+                # depth_outs = self.unidepth.infer(input, intrincs)
+                # if self.training and self.cfg.dataset.pad_border_aug > 0:
+                #     depth_outs["depth"] = F.pad(depth_outs["depth"], (pad,pad,pad,pad), mode="replicate")
+
+        outputs_gauss = {}
+
+        K = depth_outs["intrinsics"]
+        outputs_gauss[("K_src", 0)] = K
+        outputs_gauss[("inv_K_src", 0)] = torch.linalg.inv(K)
+
+        if self.cfg.model.splat_branch == "resnet":
+            if self.cfg.model.depth_cond:
+                # division by 20 is to put depth in a similar range to RGB
+                resnet_input = torch.cat([inputs["color_aug", 0, 0], 
+                                          depth_outs["depth"] / 20.0], dim=1)
+            else:
+                resnet_input = inputs["color_aug", 0, 0]
+            resnet_features = self.encoder(resnet_input)
+            if self.cfg.model.gaussians_per_pixel > 1:
+                pred_depth = dict()
+                depth = self.models["depth"](resnet_features)
+                if self.cfg.model.depth_type == "disp":
+                    for key, v in depth.items():
+                        _, pred_depth[("depth", key[1])] = disp_to_depth(v, self.cfg.model.min_depth, self.cfg.model.max_depth)
+                elif self.cfg.model.depth_type in ["depth", "depth_inc"]:
+                    pred_depth = depth
+                pred_depth[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "(b n) ... -> b n ...", n=self.cfg.model.gaussians_per_pixel - 1)
+                if self.cfg.model.depth_type in ["depth_inc", "disp_inc"]:
+                    pred_depth[("depth", 0)] = torch.cumsum(torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1), dim=1)
+                else:
+                    pred_depth[("depth", 0)] = torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1)
+                outputs_gauss[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "b n c ... -> (b n) c ...", n = self.cfg.model.gaussians_per_pixel)
+                gauss_outs = dict()
+                for i in range(self.cfg.model.gaussians_per_pixel):
+                    outs = self.models["gauss_decoder_"+str(i)](resnet_features)
+                    if not self.cfg.model.one_gauss_decoder:
+                        for key, v in outs.items():
+                            gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
+                    else:
+                        gauss_outs |= outs
+                for key, v in gauss_outs.items():
+                    gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
+                outputs_gauss |= gauss_outs
+            else:
+                outputs_gauss[("depth", 0)] = depth_outs["depth"]
+                outputs_gauss |= self.gauss_decoder(resnet_features)
+        elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
+            split_network_outputs = self.splat_branch(inputs)[1].split(self.split_dimensions, dim=1)
+            offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:5]
+
+            outputs_gauss |= {
+                ("gauss_opacity", 0): self.opacity_activation(opacity),
+                ("gauss_scaling", 0): self.scaling_activation(scaling),
+                ("gauss_rotation", 0): self.rotation_activation(rotation),
+                ("gauss_features_dc", 0): feat_dc
+            }
+
+            if self.cfg.model.max_sh_degree > 0:
+                features_rest = split_network_outputs[5]
+                outputs_gauss[("gauss_features_rest", 0)] = features_rest
+
+            assert self.cfg.model.predict_offset
+            outputs_gauss[("gauss_offset", 0)] = offset
+
+        return outputs_gauss
+    
diff --git a/flash3d/unidepth/layers/__init__.py b/flash3d/unidepth/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4acb2508c715186fadfa3b0441b8a0e981bd41e3
--- /dev/null
+++ b/flash3d/unidepth/layers/__init__.py
@@ -0,0 +1,21 @@
+from .activation import SwiGLU, GEGLU
+from .convnext import CvnxtBlock
+from .attention import AttentionBlock, AttentionDecoderBlock
+from .nystrom_attention import NystromBlock
+from .positional_encoding import PositionEmbeddingSine
+from .upsample import ConvUpsample, ConvUpsampleShuffle
+from .mlp import MLP
+
+
+__all__ = [
+    "SwiGLU",
+    "GEGLU",
+    "CvnxtBlock",
+    "AttentionBlock",
+    "NystromBlock",
+    "PositionEmbeddingSine",
+    "ConvUpsample",
+    "MLP",
+    "ConvUpsampleShuffle",
+    "AttentionDecoderBlock",
+]
diff --git a/flash3d/unidepth/layers/activation.py b/flash3d/unidepth/layers/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5787a340013ba59e2956b6b829f724d9cfb7fcc
--- /dev/null
+++ b/flash3d/unidepth/layers/activation.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SwiGLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gates = x.chunk(2, dim=-1)
+        return x * F.silu(gates)
+
+
+class GEGLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gates = x.chunk(2, dim=-1)
+        return x * F.gelu(gates)
diff --git a/flash3d/unidepth/layers/attention.py b/flash3d/unidepth/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9fc5f79003e28815e65f9f8fe71474b7ed021a1
--- /dev/null
+++ b/flash3d/unidepth/layers/attention.py
@@ -0,0 +1,308 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .layer_scale import LayerScale
+from .mlp import MLP
+
+
+class SimpleAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        context_dim: int | None = None,
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.hidden_dim = dim
+        context_dim = context_dim or dim
+
+        self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.norm_attnx = nn.LayerNorm(dim)
+        self.norm_attnctx = nn.LayerNorm(context_dim)
+        self.cosine = cosine
+        self.out = nn.Linear(dim, dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        context = x if context is None else context
+        x = self.norm_attnx(x)
+        context = self.norm_attnctx(context)
+        k, v = rearrange(
+            self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+        if rope is not None:
+            q = rope(q)
+            k = rope(k)
+        else:
+            if pos_embed is not None:
+                pos_embed = rearrange(
+                    pos_embed, "b n (h d) -> b h n d", h=self.num_heads
+                )
+                q = q + pos_embed
+            if pos_embed_context is not None:
+                pos_embed_context = rearrange(
+                    pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
+                )
+                k = k + pos_embed_context
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+        x = F.scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+        )
+        x = rearrange(x, "b h n d -> b n (h d)")
+        x = self.out(x)
+        return x
+
+
+class AttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 4,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        gated: bool = False,
+        layer_scale: float = 1.0,
+        context_dim: int | None = None,
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.hidden_dim = dim
+        context_dim = context_dim or dim
+        self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+        self.kv = nn.Linear(context_dim, dim * 2)
+        self.q = nn.Linear(dim, dim)
+        self.norm_attnx = nn.LayerNorm(dim)
+        self.norm_attnctx = nn.LayerNorm(context_dim)
+        self.cosine = cosine
+        self.out = nn.Linear(dim, dim)
+        self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+        self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+
+    def attn(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        x = self.norm_attnx(x)
+        context = self.norm_attnctx(context)
+        k, v = rearrange(
+            self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+        if rope is not None:
+            q = rope(q)
+            k = rope(k)
+        else:
+            if pos_embed is not None:
+                pos_embed = rearrange(
+                    pos_embed, "b n (h d) -> b h n d", h=self.num_heads
+                )
+                q = q + pos_embed
+            if pos_embed_context is not None:
+                pos_embed_context = rearrange(
+                    pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
+                )
+                k = k + pos_embed_context
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+
+        x = F.scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+        )
+        x = rearrange(x, "b h n d -> b n (h d)")
+        x = self.out(x)
+        return x
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        context = x if context is None else context
+        x = (
+            self.ls1(
+                self.attn(
+                    x,
+                    rope=rope,
+                    attn_bias=attn_bias,
+                    context=context,
+                    pos_embed=pos_embed,
+                    pos_embed_context=pos_embed_context,
+                )
+            )
+            + x
+        )
+        x = self.ls2(self.mlp(x)) + x
+        return x
+
+
+class AttentionDecoderBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 4,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        gated: bool = False,
+        layer_scale: float = 1.0,
+        context_dim: int | None = None,
+        single_head_ca: bool = True,
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.hidden_dim = dim
+        self.single_head_ca = single_head_ca
+        context_dim = context_dim or dim
+        self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+        self.kv_ca = nn.Linear(context_dim, dim * 2)
+        self.q_ca = nn.Linear(dim, dim)
+        self.kv_sa = nn.Linear(dim, dim * 2)
+        self.q_sa = nn.Linear(dim, dim)
+        self.norm_x_sa = nn.LayerNorm(dim)
+        self.norm_x_ca = nn.LayerNorm(dim)
+        self.norm_ctx_ca = nn.LayerNorm(context_dim)
+        self.cosine = cosine
+        self.out_ca = nn.Linear(dim, dim)
+        self.out_sa = nn.Linear(dim, dim)
+        self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+        self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+        self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+
+    def cross_attn(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        num_heads = 1 if self.single_head_ca else self.num_heads
+        x = self.norm_x_ca(x)
+        context = self.norm_ctx_ca(context)
+        k, v = rearrange(
+            self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
+
+        if rope is not None:
+            q = rope(q)
+            k = rope(k)
+        else:
+            if pos_embed is not None:
+                pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
+                q = q + pos_embed
+            if pos_embed_context is not None:
+                pos_embed_context = rearrange(
+                    pos_embed_context, "b n (h d) -> b h n d", h=num_heads
+                )
+                k = k + pos_embed_context
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+        x = F.scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+        )
+        x = rearrange(x, "b h n d -> b n (h d)")
+        x = self.out_ca(x)
+        return x
+
+    def self_attn(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        x = self.norm_x_sa(x)
+        k, v = rearrange(
+            self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+        if rope is not None:
+            q = rope(q)
+            k = rope(k)
+        elif pos_embed is not None:
+            pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
+            q = q + pos_embed
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+        x = F.scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+        )
+        x = rearrange(x, "b h n d -> b n (h d)")
+        x = self.out_sa(x)
+        return x
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        context = x if context is None else context
+        x = (
+            self.ls1(
+                self.cross_attn(
+                    x,
+                    rope=rope,
+                    attn_bias=attn_bias,
+                    context=context,
+                    pos_embed=pos_embed,
+                    pos_embed_context=pos_embed_context,
+                )
+            )
+            + x
+        )
+        x = (
+            self.ls2(
+                self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
+            )
+            + x
+        )
+        x = self.ls3(self.mlp(x)) + x
+        return x
diff --git a/flash3d/unidepth/layers/convnext.py b/flash3d/unidepth/layers/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a4e9a15e25433418d6b066f15a39a205f5aa81
--- /dev/null
+++ b/flash3d/unidepth/layers/convnext.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+
+
+class CvnxtBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        kernel_size=7,
+        layer_scale=1.0,
+        expansion=4,
+        dilation=1,
+    ):
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding="same",
+            groups=dim,
+            dilation=dilation,
+        )  # depthwise conv
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, expansion * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(expansion * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
+        )
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        x = self.gamma * x
+        x = input + x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+        return x
diff --git a/flash3d/unidepth/layers/drop_path.py b/flash3d/unidepth/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..781ff566500c923b1f199542b0c7dfb862a077ca
--- /dev/null
+++ b/flash3d/unidepth/layers/drop_path.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
diff --git a/flash3d/unidepth/layers/layer_scale.py b/flash3d/unidepth/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b6662490d7296725f103d1abf8790cac84d0f8
--- /dev/null
+++ b/flash3d/unidepth/layers/layer_scale.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: float | torch.Tensor = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/flash3d/unidepth/layers/mlp.py b/flash3d/unidepth/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..074b7e3949a12233e88a08877738e0ce2ca53acf
--- /dev/null
+++ b/flash3d/unidepth/layers/mlp.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+
+from unidepth.utils.misc import default
+from .activation import SwiGLU
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        gated: bool = False,
+        output_dim: int | None = None,
+    ):
+        super().__init__()
+        if gated:
+            expansion = int(expansion * 2 / 3)
+        hidden_dim = int(input_dim * expansion)
+        output_dim = default(output_dim, input_dim)
+        self.norm = nn.LayerNorm(input_dim)
+        self.proj1 = nn.Linear(input_dim, hidden_dim)
+        self.proj2 = nn.Linear(hidden_dim, output_dim)
+        self.act = nn.GELU() if not gated else SwiGLU()
+        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.proj1(x)
+        x = self.act(x)
+        x = self.proj2(x)
+        x = self.dropout(x)
+        return x
diff --git a/flash3d/unidepth/layers/nystrom_attention.py b/flash3d/unidepth/layers/nystrom_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7476f114a68617bf64bc4cb51eec6c98445df5
--- /dev/null
+++ b/flash3d/unidepth/layers/nystrom_attention.py
@@ -0,0 +1,74 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from xformers.components.attention import NystromAttention
+
+from .attention import AttentionBlock
+
+
+class NystromBlock(AttentionBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 4,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        gated: bool = False,
+        layer_scale: float = 1.0,
+        context_dim: int | None = None,
+    ):
+        super().__init__(
+            dim=dim,
+            num_heads=num_heads,
+            expansion=expansion,
+            dropout=dropout,
+            cosine=cosine,
+            gated=gated,
+            layer_scale=layer_scale,
+            context_dim=context_dim,
+        )
+        self.attention_fn = NystromAttention(
+            num_landmarks=128, num_heads=num_heads, dropout=dropout
+        )
+
+    def attn(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+    ) -> torch.Tensor:
+        x = self.norm_attnx(x)
+        context = self.norm_attnctx(context)
+        k, v = rearrange(
+            self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
+
+        if rope is not None:
+            q = rope(q)
+            k = rope(k)
+        else:
+            if pos_embed is not None:
+                pos_embed = rearrange(
+                    pos_embed, "b n (h d) -> b n h d", h=self.num_heads
+                )
+                q = q + pos_embed
+            if pos_embed_context is not None:
+                pos_embed_context = rearrange(
+                    pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
+                )
+                k = k + pos_embed_context
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+        x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
+        x = rearrange(x, "b n h d -> b n (h d)")
+        x = self.out(x)
+        return x
diff --git a/flash3d/unidepth/layers/positional_encoding.py b/flash3d/unidepth/layers/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..616dc197cf2e602e85085dc6f05957920f115cb5
--- /dev/null
+++ b/flash3d/unidepth/layers/positional_encoding.py
@@ -0,0 +1,228 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from math import pi
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from einops import rearrange, repeat
+
+
+class PositionEmbeddingSine(nn.Module):
+    def __init__(
+        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+    ):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * pi
+        self.scale = scale
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros(
+                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
+            )
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (
+            2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
+        )
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+    def __repr__(self, _repr_indent=4):
+        head = "Positional encoding " + self.__class__.__name__
+        body = [
+            "num_pos_feats: {}".format(self.num_pos_feats),
+            "temperature: {}".format(self.temperature),
+            "normalize: {}".format(self.normalize),
+            "scale: {}".format(self.scale),
+        ]
+        # _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
+
+
+class LearnedSinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        assert (dim % 2) == 0
+        half_dim = dim // 2
+        self.weights = nn.Parameter(torch.randn(half_dim))
+
+    def forward(self, x):
+        x = rearrange(x, "b -> b 1")
+        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+        fouriered = torch.cat((x, fouriered), dim=-1)
+        return fouriered
+
+
+def generate_fourier_features(x, max_freq=64, num_bands=16):
+    x = x.unsqueeze(-1)
+    device, dtype, orig_x = x.device, x.dtype, x
+
+    scales = torch.linspace(
+        -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
+    )
+    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+
+    x = x * scales * pi
+    x = torch.cat([x.sin(), x.cos()], dim=-1)
+    x = torch.cat((x, orig_x), dim=-1)
+    return x.flatten(-2)
+
+
+def broadcat(tensors, dim=-1):
+    num_tensors = len(tensors)
+    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+    shape_len = list(shape_lens)[0]
+    dim = (dim + shape_len) if dim < 0 else dim
+    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+    assert all(
+        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatentation"
+    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+    expanded_dims.insert(dim, (dim, dims[dim]))
+    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+    return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        pt_seq_len,
+        ft_seq_len=None,
+        custom_freqs=None,
+        freqs_for="lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+    ):
+        super().__init__()
+        if custom_freqs:
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+        else:
+            raise ValueError(f"unknown modality {freqs_for}")
+
+        if ft_seq_len is None:
+            ft_seq_len = pt_seq_len
+        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+        freqs_h = torch.einsum("..., f -> ... f", t, freqs)
+        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+
+        freqs_w = torch.einsum("..., f -> ... f", t, freqs)
+        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
+
+        self.register_buffer("freqs_cos", freqs.cos())
+        self.register_buffer("freqs_sin", freqs.sin())
+
+        print("======== shape of rope freq", self.freqs_cos.shape, "========")
+
+    def forward(self, t, start_index=0):
+        rot_dim = self.freqs_cos.shape[-1]
+        end_index = start_index + rot_dim
+        assert (
+            rot_dim <= t.shape[-1]
+        ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+        t_left, t, t_right = (
+            t[..., :start_index],
+            t[..., start_index:end_index],
+            t[..., end_index:],
+        )
+        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
+        return torch.cat((t_left, t, t_right), dim=-1)
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+    def __init__(
+        self,
+        dim,
+        pt_seq_len,
+        ft_seq_len=None,
+        custom_freqs=None,
+        freqs_for="lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+    ):
+        super().__init__()
+        if custom_freqs:
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+        else:
+            raise ValueError(f"unknown modality {freqs_for}")
+
+        if ft_seq_len is None:
+            ft_seq_len = pt_seq_len
+        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+        freqs = torch.einsum("..., f -> ... f", t, freqs)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+        self.register_buffer("freqs_cos", freqs_cos)
+        self.register_buffer("freqs_sin", freqs_sin)
+
+    def forward(self, t):
+        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
diff --git a/flash3d/unidepth/layers/upsample.py b/flash3d/unidepth/layers/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0162e76c3319c8ac802b68795d6f32793e693b6
--- /dev/null
+++ b/flash3d/unidepth/layers/upsample.py
@@ -0,0 +1,69 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from .convnext import CvnxtBlock
+
+
+class ConvUpsample(nn.Module):
+    def __init__(
+        self,
+        hidden_dim,
+        num_layers: int = 2,
+        expansion: int = 4,
+        layer_scale: float = 1.0,
+        kernel_size: int = 7,
+        **kwargs
+    ):
+        super().__init__()
+        self.convs = nn.ModuleList([])
+        for _ in range(num_layers):
+            self.convs.append(
+                CvnxtBlock(
+                    hidden_dim,
+                    kernel_size=kernel_size,
+                    expansion=expansion,
+                    layer_scale=layer_scale,
+                )
+            )
+        self.up = nn.Sequential(
+            nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
+            nn.UpsamplingBilinear2d(scale_factor=2),
+            nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor):
+        for conv in self.convs:
+            x = conv(x)
+        x = self.up(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        return x
+
+
+class ConvUpsampleShuffle(nn.Module):
+    def __init__(
+        self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs
+    ):
+        super().__init__()
+        self.conv1 = CvnxtBlock(
+            hidden_dim, expansion=expansion, layer_scale=layer_scale
+        )
+        self.conv2 = CvnxtBlock(
+            hidden_dim, expansion=expansion, layer_scale=layer_scale
+        )
+        self.up = nn.Sequential(
+            nn.PixelShuffle(2),
+            nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.up(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        return x
diff --git a/flash3d/unidepth/models/__init__.py b/flash3d/unidepth/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1781bda94cdc13b0c0c805e7cde0872defc20cd3
--- /dev/null
+++ b/flash3d/unidepth/models/__init__.py
@@ -0,0 +1,5 @@
+from .unidepthv1 import UniDepthV1
+
+__all__ = [
+    "UniDepthV1",
+]
diff --git a/flash3d/unidepth/models/backbones/__init__.py b/flash3d/unidepth/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f55cde8b365f0f63f98fcc21ea64166c7ce33c
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/__init__.py
@@ -0,0 +1,9 @@
+from .convnext2 import ConvNeXtV2
+from .convnext import ConvNeXt
+from .dinov2 import _make_dinov2_model
+
+__all__ = [
+    "ConvNeXt",
+    "ConvNeXtV2",
+    "_make_dinov2_model",
+]
diff --git a/flash3d/unidepth/models/backbones/convnext.py b/flash3d/unidepth/models/backbones/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..b745415724df69347697efc9987c3b8a6c9cb849
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/convnext.py
@@ -0,0 +1,590 @@
+from collections import OrderedDict
+from functools import partial
+from typing import Callable, Optional, Tuple, Union, Sequence
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from timm.layers import (
+    trunc_normal_,
+    AvgPool2dSame,
+    DropPath,
+    Mlp,
+    GlobalResponseNormMlp,
+    LayerNorm2d,
+    LayerNorm,
+    create_conv2d,
+    get_act_layer,
+    make_divisible,
+    to_ntuple,
+)
+
+
+def get_num_layer_for_convnext(var_name):
+    """
+    Divide [3, 3, 27, 3] layers into 12 groups; each group is three
+    consecutive blocks, including possible neighboring downsample layers;
+    adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
+    """
+    if var_name.startswith("downsample_layers"):
+        stage_id = int(var_name.split(".")[1])
+        if stage_id == 0:
+            layer_id = 0
+        elif stage_id == 1 or stage_id == 2:
+            layer_id = stage_id + 1
+        elif stage_id == 3:
+            layer_id = 12
+
+    elif var_name.startswith("stages"):
+        stage_id = int(var_name.split(".")[1])
+        block_id = int(var_name.split(".")[3])
+        if stage_id == 0 or stage_id == 1:
+            layer_id = stage_id + 1
+        elif stage_id == 2:
+            layer_id = 3 + block_id // 3
+        elif stage_id == 3:
+            layer_id = 12
+
+    elif var_name.startswith("stem"):
+        return 0
+    else:
+        layer_id = 12
+    return layer_id + 1
+
+
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
+    parameter_group_names = {}
+    parameter_group_vars = {}
+    skip = set()
+    if skip_list is not None:
+        skip = skip_list
+    if hasattr(model, "no_weight_decay"):
+        skip.update(model.no_weight_decay())
+    num_layers = 12
+    layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
+            group_name = "no_decay"
+            this_wd = 0.0
+        else:
+            group_name = "decay"
+            this_wd = wd
+
+        layer_id = get_num_layer_for_convnext(name)
+        group_name = "layer_%d_%s" % (layer_id, group_name)
+
+        if group_name not in parameter_group_names:
+            scale = layer_scale[layer_id]
+            cur_lr = lr * scale
+
+            parameter_group_names[group_name] = {
+                "weight_decay": this_wd,
+                "weight_decay_init": this_wd,
+                "weight_decay_base": this_wd,
+                "params": [],
+                "lr_init": cur_lr,
+                "lr_base": lr,
+                "lr": cur_lr,
+            }
+            parameter_group_vars[group_name] = {
+                "weight_decay": this_wd,
+                "weight_decay_init": this_wd,
+                "weight_decay_base": this_wd,
+                "params": [],
+                "lr_init": cur_lr,
+                "lr_base": lr,
+                "lr": cur_lr,
+            }
+            if this_wd == 0.0:
+                parameter_group_names[group_name]["weight_decay_final"] = 0.0
+                parameter_group_vars[group_name]["weight_decay_final"] = 0.0
+        parameter_group_vars[group_name]["params"].append(param)
+        parameter_group_names[group_name]["params"].append(name)
+    # from unidepth.utils import is_main_process
+    # import json
+    # if is_main_process():
+    #     print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+    return list(parameter_group_vars.values()), [
+        v["lr"] for k, v in parameter_group_vars.items()
+    ]
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = (
+                AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            )
+            self.pool = avg_pool_fn(
+                2, avg_stride, ceil_mode=True, count_include_pad=False
+            )
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
+class ConvNeXtBlock(nn.Module):
+    """ConvNeXt Block
+    There are two equivalent implementations:
+      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+
+    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
+    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
+    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
+    """
+
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: Optional[int] = None,
+        kernel_size: int = 7,
+        stride: int = 1,
+        dilation: Union[int, Tuple[int, int]] = (1, 1),
+        mlp_ratio: float = 4,
+        conv_mlp: bool = False,
+        conv_bias: bool = True,
+        use_grn: bool = False,
+        ls_init_value: Optional[float] = 1e-6,
+        act_layer: Union[str, Callable] = "gelu",
+        norm_layer: Optional[Callable] = None,
+        drop_path: float = 0.0,
+    ):
+        """
+
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
+        super().__init__()
+        out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
+        act_layer = get_act_layer(act_layer)
+        if not norm_layer:
+            norm_layer = LayerNorm2d if conv_mlp else LayerNorm
+        mlp_layer = partial(
+            GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
+        )
+        self.use_conv_mlp = conv_mlp
+        self.conv_dw = create_conv2d(
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
+        self.norm = norm_layer(out_chs)
+        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+        self.gamma = (
+            nn.Parameter(ls_init_value * torch.ones(out_chs))
+            if ls_init_value is not None
+            else None
+        )
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(
+                in_chs, out_chs, stride=stride, dilation=dilation[0]
+            )
+        else:
+            self.shortcut = nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        shortcut = x
+        x = self.conv_dw(x.contiguous())
+        if self.use_conv_mlp:
+            x = self.norm(x)
+            x = self.mlp(x)
+        else:
+            x = x.permute(0, 2, 3, 1).contiguous()
+            x = self.norm(x)
+            x = self.mlp(x)
+            x = x.permute(0, 3, 1, 2).contiguous()
+        if self.gamma is not None:
+            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
+
+        x = self.drop_path(x) + self.shortcut(shortcut)
+        return x.contiguous()
+
+
+class ConvNeXtStage(nn.Module):
+    def __init__(
+        self,
+        in_chs,
+        out_chs,
+        kernel_size=7,
+        stride=2,
+        depth=2,
+        dilation=(1, 1),
+        drop_path_rates=None,
+        ls_init_value=1.0,
+        conv_mlp=False,
+        conv_bias=True,
+        use_grn=False,
+        act_layer="gelu",
+        norm_layer=None,
+        norm_layer_cl=None,
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+
+        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+            pad = (
+                "same" if dilation[1] > 1 else 0
+            )  # same padding needed if dilation used
+            self.downsample = nn.Sequential(
+                norm_layer(in_chs),
+                create_conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
+            )
+            in_chs = out_chs
+        else:
+            self.downsample = nn.Identity()
+
+        drop_path_rates = drop_path_rates or [0.0] * depth
+        stage_blocks = []
+        for i in range(depth):
+            stage_blocks.append(
+                ConvNeXtBlock(
+                    in_chs=in_chs,
+                    out_chs=out_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation[1],
+                    drop_path=drop_path_rates[i],
+                    ls_init_value=ls_init_value,
+                    conv_mlp=conv_mlp,
+                    conv_bias=conv_bias,
+                    use_grn=use_grn,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer if conv_mlp else norm_layer_cl,
+                )
+            )
+            in_chs = out_chs
+        self.blocks = nn.ModuleList(stage_blocks)
+
+    def forward(self, x):
+        xs = []
+        x = self.downsample(x)
+        for block in self.blocks:
+            if self.grad_checkpointing:
+                x = checkpoint(block, x)
+            else:
+                x = block(x)
+            xs.append(x)
+        return xs
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        in_chans: int = 3,
+        output_stride: int = 32,
+        depths: Tuple[int, ...] = (3, 3, 9, 3),
+        dims: Tuple[int, ...] = (96, 192, 384, 768),
+        kernel_sizes: Union[int, Tuple[int, ...]] = 7,
+        ls_init_value: Optional[float] = 1e-6,
+        stem_type: str = "patch",
+        patch_size: int = 4,
+        conv_mlp: bool = False,
+        conv_bias: bool = True,
+        use_grn: bool = False,
+        act_layer: Union[str, Callable] = "gelu",
+        norm_layer: Optional[Union[str, Callable]] = None,
+        norm_eps: Optional[float] = None,
+        drop_path_rate: float = 0.0,
+        output_idx=[],
+        use_checkpoint=False,
+    ):
+        """
+        Args:
+            in_chans: Number of input image channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Global pooling type.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            depths: Number of blocks at each stage.
+            dims: Feature dimension at each stage.
+            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            head_init_scale: Init scaling value for classifier weights and biases.
+            head_norm_first: Apply normalization before global pool + head.
+            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
+            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
+            conv_bias: Use bias layers w/ all convolutions.
+            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            drop_rate: Head pre-classifier dropout rate.
+            drop_path_rate: Stochastic depth drop rate.
+        """
+        super().__init__()
+        self.num_layers = len(depths)
+        self.depths = output_idx
+        self.embed_dims = [
+            int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
+        ]
+        self.embed_dim = dims[0]
+
+        assert output_stride in (8, 16, 32)
+        kernel_sizes = to_ntuple(4)(kernel_sizes)
+        if norm_layer is None:
+            norm_layer = LayerNorm2d
+            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+            if norm_eps is not None:
+                norm_layer = partial(norm_layer, eps=norm_eps)
+                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        else:
+            assert (
+                conv_mlp
+            ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
+            norm_layer_cl = norm_layer
+            if norm_eps is not None:
+                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+
+        self.feature_info = []
+
+        assert stem_type in ("patch", "overlap", "overlap_tiered")
+        if stem_type == "patch":
+            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
+            self.stem = nn.Sequential(
+                nn.Conv2d(
+                    in_chans,
+                    dims[0],
+                    kernel_size=patch_size,
+                    stride=patch_size,
+                    bias=conv_bias,
+                ),
+                norm_layer(dims[0]),
+            )
+            stem_stride = patch_size
+        else:
+            mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
+            self.stem = nn.Sequential(
+                nn.Conv2d(
+                    in_chans,
+                    mid_chs,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=conv_bias,
+                ),
+                nn.Conv2d(
+                    mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
+                ),
+                norm_layer(dims[0]),
+            )
+            stem_stride = 4
+
+        self.stages = nn.Sequential()
+        dp_rates = [
+            x.tolist()
+            for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
+        ]
+        stages = []
+        prev_chs = dims[0]
+        curr_stride = stem_stride
+        dilation = 1
+        # 4 feature resolution stages, each consisting of multiple residual blocks
+        for i in range(4):
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            if curr_stride >= output_stride and stride > 1:
+                dilation *= stride
+                stride = 1
+            curr_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
+            out_chs = dims[i]
+            stages.append(
+                ConvNeXtStage(
+                    prev_chs,
+                    out_chs,
+                    kernel_size=kernel_sizes[i],
+                    stride=stride,
+                    dilation=(first_dilation, dilation),
+                    depth=depths[i],
+                    drop_path_rates=dp_rates[i],
+                    ls_init_value=ls_init_value,
+                    conv_mlp=conv_mlp,
+                    conv_bias=conv_bias,
+                    use_grn=use_grn,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    norm_layer_cl=norm_layer_cl,
+                )
+            )
+            prev_chs = out_chs
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            self.feature_info += [
+                dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
+            ]
+        self.stages = nn.ModuleList(stages)
+        self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
+        self.num_features = prev_chs
+        self.apply(self._init_weights)
+        self.set_grad_checkpointing(use_checkpoint)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Conv2d):
+            trunc_normal_(module.weight, std=0.02)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Linear):
+            trunc_normal_(module.weight, std=0.02)
+            nn.init.zeros_(module.bias)
+
+    def forward(self, x, masks=None):
+        outs = []
+        x = self.stem(x)
+        if masks is not None:
+            masks = torch.nn.functional.interpolate(
+                masks.float(), size=x.shape[-2:], mode="nearest"
+            )
+            x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
+        for stage in self.stages:
+            xs = stage(x)
+            outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
+            x = xs[-1]
+        return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r"^stem",
+            blocks=(
+                r"^stages\.(\d+)"
+                if coarse
+                else [
+                    (r"^stages\.(\d+)\.downsample", (0,)),  # blocks
+                    (r"^stages\.(\d+)\.blocks\.(\d+)", None),
+                    (r"^norm_pre", (99999,)),
+                ]
+            ),
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    def freeze(self) -> None:
+        for module in self.modules():
+            module.eval()
+        for parameters in self.parameters():
+            parameters.requires_grad = False
+
+    def get_params(self, lr, wd, ld, *args, **kwargs):
+        encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+        return encoder_p, encoder_lr
+
+    def no_weight_decay(self):
+        return {"mask_token"}
+
+    @classmethod
+    def build(cls, config):
+        obj = globals()[config["model"]["encoder"]["name"]](config)
+        return obj
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """Remap FB checkpoints -> timm"""
+    if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
+        return state_dict  # non-FB checkpoint
+    if "model" in state_dict:
+        state_dict = state_dict["model"]
+
+    out_dict = {}
+    if "visual.trunk.stem.0.weight" in state_dict:
+        out_dict = {
+            k.replace("visual.trunk.", ""): v
+            for k, v in state_dict.items()
+            if k.startswith("visual.trunk.")
+        }
+        if "visual.head.proj.weight" in state_dict:
+            out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
+            out_dict["head.fc.bias"] = torch.zeros(
+                state_dict["visual.head.proj.weight"].shape[0]
+            )
+        elif "visual.head.mlp.fc1.weight" in state_dict:
+            out_dict["head.pre_logits.fc.weight"] = state_dict[
+                "visual.head.mlp.fc1.weight"
+            ]
+            out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
+            out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
+            out_dict["head.fc.bias"] = torch.zeros(
+                state_dict["visual.head.mlp.fc2.weight"].shape[0]
+            )
+        return out_dict
+
+    import re
+
+    for k, v in state_dict.items():
+        k = k.replace("downsample_layers.0.", "stem.")
+        k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
+        k = re.sub(
+            r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
+        )
+        k = k.replace("dwconv", "conv_dw")
+        k = k.replace("pwconv", "mlp.fc")
+        if "grn" in k:
+            k = k.replace("grn.beta", "mlp.grn.bias")
+            k = k.replace("grn.gamma", "mlp.grn.weight")
+            v = v.reshape(v.shape[-1])
+        k = k.replace("head.", "head.fc.")
+        if k.startswith("norm."):
+            k = k.replace("norm", "head.norm")
+        if v.ndim == 2 and "head" not in k:
+            model_shape = model.state_dict()[k].shape
+            v = v.reshape(model_shape)
+        out_dict[k] = v
+
+    return out_dict
+
+
+HF_URL = {
+    "convnext_xxlarge_pt": (
+        "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
+        "open_clip_pytorch_model.bin",
+    ),
+    "convnext_large_pt": (
+        "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
+        "open_clip_pytorch_model.bin",
+    ),
+    "convnext_large": (
+        "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
+        "pytorch_model.bin",
+    ),
+}
diff --git a/flash3d/unidepth/models/backbones/convnext2.py b/flash3d/unidepth/models/backbones/convnext2.py
new file mode 100644
index 0000000000000000000000000000000000000000..793538172b043f683d0856ddd68e48355774ca46
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/convnext2.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+
+
+def get_num_layer_for_convnext_single(var_name, depths):
+    """
+    Each layer is assigned distinctive layer ids
+    """
+    if var_name.startswith("downsample_layers"):
+        stage_id = int(var_name.split(".")[1])
+        layer_id = sum(depths[:stage_id]) + 1
+        return layer_id
+
+    elif var_name.startswith("stages"):
+        stage_id = int(var_name.split(".")[1])
+        block_id = int(var_name.split(".")[2])
+        layer_id = sum(depths[:stage_id]) + block_id + 1
+        return layer_id
+
+    else:
+        return sum(depths) + 1
+
+
+def get_num_layer_for_convnext(var_name):
+    """
+    Divide [3, 3, 27, 3] layers into 12 groups; each group is three
+    consecutive blocks, including possible neighboring downsample layers;
+    adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
+    """
+    num_max_layer = 12
+    if var_name.startswith("downsample_layers"):
+        stage_id = int(var_name.split(".")[1])
+        if stage_id == 0:
+            layer_id = 0
+        elif stage_id == 1 or stage_id == 2:
+            layer_id = stage_id + 1
+        elif stage_id == 3:
+            layer_id = 12
+        return layer_id
+
+    elif var_name.startswith("stages"):
+        stage_id = int(var_name.split(".")[1])
+        block_id = int(var_name.split(".")[2])
+        if stage_id == 0 or stage_id == 1:
+            layer_id = stage_id + 1
+        elif stage_id == 2:
+            layer_id = 3 + block_id // 3
+        elif stage_id == 3:
+            layer_id = 12
+        return layer_id
+    else:
+        return num_max_layer + 1
+
+
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
+    parameter_group_names = {}
+    parameter_group_vars = {}
+    skip = {}
+    if skip_list is not None:
+        skip = skip_list
+    elif hasattr(model, "no_weight_decay"):
+        skip = model.no_weight_decay()
+    num_layers = 12  # sum(model.depths)
+    layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if (
+            len(param.shape) == 1
+            or name.endswith(".bias")
+            or name in skip
+            or name.endswith(".gamma")
+            or name.endswith(".beta")
+        ):
+            group_name = "no_decay"
+            this_weight_decay = 0.0
+        else:
+            group_name = "decay"
+            this_weight_decay = wd
+
+        # layer_id = get_num_layer_for_convnext_single(name, model.depths)
+        layer_id = get_num_layer_for_convnext(name)
+        group_name = "layer_%d_%s" % (layer_id, group_name)
+
+        if group_name not in parameter_group_names:
+            scale = layer_scale[layer_id]
+            cur_lr = lr * scale
+
+            parameter_group_names[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr_scale": scale,
+                "lr": cur_lr,
+            }
+            parameter_group_vars[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr_scale": scale,
+                "lr": cur_lr,
+            }
+        parameter_group_vars[group_name]["params"].append(param)
+        parameter_group_names[group_name]["params"].append(name)
+    # if is_main_process():
+    # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+    return list(parameter_group_vars.values()), [
+        v["lr"] for k, v in parameter_group_vars.items()
+    ]
+
+
+class LayerNorm(nn.Module):
+    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+
+
+class GRN(nn.Module):
+    """GRN (Global Response Normalization) layer"""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+    def forward(self, x):
+        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+class Block(nn.Module):
+    """ConvNeXtV2 Block.
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+    """
+
+    def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim, dim, kernel_size=7, padding=3, groups=dim
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, mult * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.grn = GRN(mult * dim)
+        self.pwconv2 = nn.Linear(mult * dim, dim)
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.grn(x)
+        x = self.pwconv2(x)
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class ConvNeXtV2(nn.Module):
+    """ConvNeXt V2
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+
+    def __init__(
+        self,
+        in_chans=3,
+        depths=[3, 3, 9, 3],
+        dims=96,
+        drop_path_rate=0.0,
+        output_idx=[],
+        use_checkpoint=False,
+    ):
+        super().__init__()
+        self.num_layers = len(depths)
+        self.depths = output_idx
+        self.embed_dims = [
+            int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
+        ]
+        self.embed_dim = dims[0]
+
+        self.downsample_layers = (
+            nn.ModuleList()
+        )  # stem and 3 intermediate downsampling conv layers
+        stem = nn.Sequential(
+            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = (
+            nn.ModuleList()
+        )  # 4 feature resolution stages, each consisting of multiple residual blocks
+        self.out_norms = nn.ModuleList()
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        for i in range(4):
+            stage = nn.ModuleList(
+                [
+                    Block(
+                        dim=dims[i],
+                        drop_path=dp_rates[cur + j],
+                        use_checkpoint=use_checkpoint,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        outs = []
+        for i in range(4):
+            x = self.downsample_layers[i](x)
+            for stage in self.stages[i]:
+                x = stage(x)
+                outs.append(x.permute(0, 2, 3, 1))
+        cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
+        return outs, cls_tokens
+
+    def get_params(self, lr, wd, ld, *args, **kwargs):
+        encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+        return encoder_p, encoder_lr
+
+    def freeze(self) -> None:
+        for module in self.modules():
+            module.eval()
+        for parameters in self.parameters():
+            parameters.requires_grad = False
+
+    @classmethod
+    def build(cls, config):
+        obj = globals()[config["model"]["encoder"]["name"]](config)
+        return obj
diff --git a/flash3d/unidepth/models/backbones/dinov2.py b/flash3d/unidepth/models/backbones/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c0a25e2b5091d6eb435fb56f68cf292bebebf4
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/dinov2.py
@@ -0,0 +1,552 @@
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+
+from .metadinov2 import (
+    Mlp,
+    PatchEmbed,
+    SwiGLUFFNFused,
+    MemEffAttention,
+    NestedTensorBlock as Block,
+)
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(
+    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
+) -> nn.Module:
+    if not depth_first and include_root:
+        fn(module=module, name=name)
+    for child_name, child_module in module.named_children():
+        child_name = ".".join((name, child_name)) if name else child_name
+        named_apply(
+            fn=fn,
+            module=child_module,
+            name=child_name,
+            depth_first=depth_first,
+            include_root=True,
+        )
+    if depth_first and include_root:
+        fn(module=module, name=name)
+    return module
+
+
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
+    parameter_group_names = {}
+    parameter_group_vars = {}
+    skip = {}
+    if skip_list is not None:
+        skip = skip_list
+    elif hasattr(model, "no_weight_decay"):
+        skip = model.no_weight_decay()
+
+    num_layers = model.n_blocks
+    layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if len(param.shape) == 1:  # norm
+            group_name = "no_decay"
+            this_wd = 0.0
+        # layer scale, bias beta?
+        elif (
+            name in skip
+            or name.endswith(".gamma")
+            or name.endswith(".beta")
+            or name.endswith(".bias")
+        ):
+            group_name = "no_decay"
+            this_wd = 0.0
+        elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
+            group_name = "no_decay"
+            this_wd = 0.0
+        else:
+            group_name = "decay"
+            this_wd = wd
+
+        if name.startswith("blocks"):
+            layer_id = int(name.split(".")[1])
+        elif name.startswith("patch_embed"):
+            layer_id = 0
+        else:
+            layer_id = 0
+
+        group_name = f"layer_{layer_id}_{group_name}"
+
+        if group_name not in parameter_group_names:
+            scale = layer_scale[layer_id]
+            cur_lr = lr * scale
+
+            parameter_group_names[group_name] = {
+                "weight_decay": this_wd,
+                "params": [],
+                "lr_init": cur_lr,
+                "lr_base": lr,
+                "lr": cur_lr,
+            }
+            parameter_group_vars[group_name] = {
+                "weight_decay": this_wd,
+                "params": [],
+                "lr_init": cur_lr,
+                "lr_base": lr,
+                "lr": cur_lr,
+            }
+        parameter_group_vars[group_name]["params"].append(param)
+        parameter_group_names[group_name]["params"].append(name)
+    return list(parameter_group_vars.values()), [
+        v["lr"] for k, v in parameter_group_vars.items()
+    ]
+
+
+class BlockChunk(nn.ModuleList):
+    def forward(self, x):
+        for b in self:
+            x = b(x)
+        return x
+
+
+class DinoVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=None,  # for layerscale: None or 0 => no layerscale
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=Block,
+        ffn_layer="mlp",
+        block_chunks=1,
+        output_idx=[5, 12, 18, 24],
+        checkpoint: bool = False,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            weight_init (str): weight init scheme
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+        """
+        super().__init__()
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.num_features = self.embed_dim = (
+            embed_dim  # num_features for consistency with other models
+        )
+        self.embed_dims = [embed_dim] * output_idx[-1]
+        self.num_tokens = 1
+        self.n_blocks = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+        self.depths = output_idx
+        self.checkpoint = checkpoint
+        self.num_register_tokens = num_register_tokens
+        self.interpolate_antialias = interpolate_antialias
+        self.interpolate_offset = interpolate_offset
+
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
+        )
+        assert num_register_tokens >= 0
+        self.register_tokens = nn.Parameter(
+            torch.zeros(1, max(1, num_register_tokens), embed_dim)
+        )
+
+        if drop_path_uniform is True:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = [
+                x.item() for x in torch.linspace(0, drop_path_rate, depth)
+            ]  # stochastic depth decay rule
+
+        if ffn_layer == "mlp":
+            logger.info("using MLP layer as FFN")
+            ffn_layer = Mlp
+        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+            logger.info("using SwiGLU layer as FFN")
+            ffn_layer = SwiGLUFFNFused
+        elif ffn_layer == "identity":
+            logger.info("using Identity layer as FFN")
+
+            def f(*args, **kwargs):
+                return nn.Identity()
+
+            ffn_layer = f
+        else:
+            raise NotImplementedError
+
+        blocks_list = [
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                proj_bias=proj_bias,
+                ffn_bias=ffn_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)
+        ]
+        if block_chunks > 0:
+            self.chunked_blocks = True
+            chunked_blocks = []
+            chunksize = depth // block_chunks
+            for i in range(0, depth, chunksize):
+                # this is to keep the block index consistent if we chunk the block list
+                chunked_blocks.append(
+                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
+                )
+            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+        else:
+            self.chunked_blocks = False
+            self.blocks = nn.ModuleList(blocks_list)
+
+        # self.norm = norm_layer(embed_dim)
+        self.head = nn.Identity()
+        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+        self.init_weights()
+
+    def init_weights(self):
+        trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.normal_(self.cls_token, std=1e-6)
+        if self.num_register_tokens:
+            nn.init.normal_(self.register_tokens, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(
+                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
+            ).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode="bicubic",
+            antialias=self.interpolate_antialias,
+        )
+
+        assert (
+            int(w0) == patch_pos_embed.shape[-2]
+            and int(h0) == patch_pos_embed.shape[-1]
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+            previous_dtype
+        )
+
+    def prepare_tokens_with_masks(self, x, masks=None):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            masks = masks.bool().view(B, -1, 1)
+            x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if self.num_register_tokens:
+            x = torch.cat(
+                (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
+                dim=1,
+            )
+
+        return x
+
+    def forward_features(self, x, masks=None):
+        # if isinstance(x, list):
+        #     return self.forward_features_list(x, masks)
+        shapes = [val // self.patch_size for val in x.shape[-2:]]
+        batch_size = x.shape[0]
+        x = self.prepare_tokens_with_masks(x, masks)
+        output, cls_tokens = [], []
+
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            cls_token = x[:, :1]
+
+            out = x[:, self.num_register_tokens + 1 :]
+            # was like this before, add cls to dense features
+            # out = out + cls_token
+
+            output.append(out.view(batch_size, *shapes, -1))
+            cls_tokens.append(cls_token)
+        return (output, cls_tokens)
+
+    def get_params(self, lr, wd, ld, *args, **kwargs):
+        encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+        return encoder_p, encoder_lr
+
+    def freeze(self) -> None:
+        for module in self.modules():
+            module.eval()
+        for parameters in self.parameters():
+            parameters.requires_grad = False
+
+    def train(self, mode=True):
+        super().train(mode)
+        self.mask_token.requires_grad = False
+        self.register_tokens.requires_grad = False
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        return ret
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=0.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        num_register_tokens=num_register_tokens,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        num_register_tokens=num_register_tokens,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+    """
+    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+    """
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1536,
+        depth=40,
+        num_heads=24,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+import torch
+import torch.nn as nn
+
+
+dependencies = ["torch"]
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
+    compact_arch_name = arch_name.replace("_", "")[:4]
+    return f"dinov2_{compact_arch_name}{patch_size}"
+
+
+def _make_dinov2_model(
+    *,
+    arch_name: str = "vit_large",
+    img_size: int = 518,
+    patch_size: int = 14,
+    init_values: float = 1.0,
+    ffn_layer: str = "mlp",
+    block_chunks: int = 0,
+    pretrained: str = "",
+    output_idx: Sequence[int] = [],
+    num_register_tokens: int = 0,
+    drop_path_rate: float = 0.0,
+    **kwargs,
+):
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    print("Instantiate:", model_name)
+
+    vit_kwargs = dict(
+        img_size=img_size,
+        patch_size=patch_size,
+        init_values=init_values,
+        ffn_layer=ffn_layer,
+        block_chunks=block_chunks,
+        output_idx=output_idx,
+        drop_path_rate=drop_path_rate,
+        num_register_tokens=num_register_tokens,
+    )
+    vit_kwargs.update(**kwargs)
+    model = eval(arch_name)(**vit_kwargs)
+    if pretrained == "":
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
+        if num_register_tokens > 0:
+            url += "_reg4"
+        url += "_pretrain.pth"
+        state_dict = torch.hub.load_state_dict_from_url(
+            url, map_location="cpu", progress=False
+        )
+        info = model.load_state_dict(state_dict, strict=False)
+        print(info)
+    elif pretrained is not None:
+        state_dict = torch.load(pretrained, map_location="cpu")
+        info = model.load_state_dict(state_dict, strict=False)
+        print(f"loading from {pretrained} with:", info)
+    return model
+
+    # def forward_features_list(self, x_list, masks_list):
+    #     x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+    #     for blk in self.blocks:
+    #         x = blk(x)
+
+    #     all_x = x
+    #     output = []
+    #     for x, masks in zip(all_x, masks_list):
+    #         x_norm = self.norm(x)
+    #         output.append(
+    #             {
+    #                 "x_norm_clstoken": x_norm[:, 0],
+    #                 "x_norm_patchtokens": x_norm[:, 1:],
+    #                 "x_prenorm": x,
+    #                 "masks": masks,
+    #             }
+    #         )
+    #     return output
+
+    # def _get_intermediate_layers_not_chunked(self, x, n=1):
+    #     x = self.prepare_tokens_with_masks(x)
+    #     # If n is an int, take the n last blocks. If it's a list, take them
+    #     output, total_block_len = [], len(self.blocks)
+    #     blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+    #     for i, blk in enumerate(self.blocks):
+    #         x = blk(x)
+    #         if i in blocks_to_take:
+    #             output.append(x)
+    #     assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+    #     return output
+
+    # def _get_intermediate_layers_chunked(self, x, n=1):
+    #     x = self.prepare_tokens_with_masks(x)
+    #     output, i, total_block_len = [], 0, len(self.blocks[-1])
+    #     # If n is an int, take the n last blocks. If it's a list, take them
+    #     blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+    #     for block_chunk in self.blocks:
+    #         for blk in block_chunk[i:]:  # Passing the nn.Identity()
+    #             x = blk(x)
+    #             if i in blocks_to_take:
+    #                 output.append(x)
+    #             i += 1
+    #     assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+    #     return output
+
+    # def get_intermediate_layers(
+    #     self,
+    #     x: torch.Tensor,
+    #     n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+    #     reshape: bool = False,
+    #     return_class_token: bool = False,
+    #     norm=True,
+    # ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+    #     if self.chunked_blocks:
+    #         outputs = self._get_intermediate_layers_chunked(x, n)
+    #     else:
+    #         outputs = self._get_intermediate_layers_not_chunked(x, n)
+    #     if norm:
+    #         outputs = [self.norm(out) for out in outputs]
+    #     class_tokens = [out[:, 0] for out in outputs]
+    #     outputs = [out[:, 1:] for out in outputs]
+    #     if reshape:
+    #         B, _, w, h = x.shape
+    #         outputs = [
+    #             out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+    #             for out in outputs
+    #         ]
+    #     if return_class_token:
+    #         return tuple(zip(outputs, class_tokens))
+    #     return tuple(outputs)
diff --git a/flash3d/unidepth/models/backbones/metadinov2/__init__.py b/flash3d/unidepth/models/backbones/metadinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/flash3d/unidepth/models/backbones/metadinov2/attention.py b/flash3d/unidepth/models/backbones/metadinov2/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..38efc12df276fff129441805d260f9a8107a06d6
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/attention.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+import torch.nn as nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import memory_efficient_attention, unbind, fmha
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+        if not XFORMERS_AVAILABLE:
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = x.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
diff --git a/flash3d/unidepth/models/backbones/metadinov2/block.py b/flash3d/unidepth/models/backbones/metadinov2/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..c568363443383aa107c07ec65b4bd2ec901575c0
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/block.py
@@ -0,0 +1,284 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+import torch.nn as nn
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha
+    from xformers.ops import scaled_index_add, index_select_cat
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: torch.Tensor,
+    residual_func: Callable[[torch.Tensor], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> torch.Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    x_subset = x[brange]
+
+    # 2) apply residual_func to get residual
+    residual = residual_func(x_subset)
+
+    x_flat = x.flatten(1)
+    residual = residual.flatten(1)
+
+    residual_scale_factor = b / sample_subset_size
+
+    # 3) add the residual
+    x_plus_residual = torch.index_add(
+        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+    )
+    return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    residual_scale_factor = b / sample_subset_size
+    return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+    if scaling_vector is None:
+        x_flat = x.flatten(1)
+        residual = residual.flatten(1)
+        x_plus_residual = torch.index_add(
+            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+        )
+    else:
+        x_plus_residual = scaled_index_add(
+            x,
+            brange,
+            residual.to(dtype=x.dtype),
+            scaling=scaling_vector,
+            alpha=residual_scale_factor,
+        )
+    return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+    """
+    this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    """
+    batch_sizes = (
+        [b.shape[0] for b in branges]
+        if branges is not None
+        else [x.shape[0] for x in x_list]
+    )
+    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache.keys():
+        seqlens = []
+        for b, x in zip(batch_sizes, x_list):
+            for _ in range(b):
+                seqlens.append(x.shape[1])
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias._batch_sizes = batch_sizes
+        attn_bias_cache[all_shapes] = attn_bias
+
+    if branges is not None:
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+            1, -1, x_list[0].shape[-1]
+        )
+    else:
+        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+        cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+    return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+    x_list: List[torch.Tensor],
+    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+    scaling_vector=None,
+) -> torch.Tensor:
+    # 1) generate random set of indices for dropping samples in the batch
+    branges_scales = [
+        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+    ]
+    branges = [s[0] for s in branges_scales]
+    residual_scale_factors = [s[1] for s in branges_scales]
+
+    # 2) get attention bias and index+concat the tensors
+    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+    # 3) apply residual_func to get residual, and split the result
+    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+    outputs = []
+    for x, brange, residual, residual_scale_factor in zip(
+        x_list, branges, residual_list, residual_scale_factors
+    ):
+        outputs.append(
+            add_residual(
+                x, brange, residual, residual_scale_factor, scaling_vector
+            ).view_as(x)
+        )
+    return outputs
+
+
+class NestedTensorBlock(Block):
+    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=(
+                    self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
+                ),
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=(
+                    self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
+                ),
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, torch.Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            assert (
+                XFORMERS_AVAILABLE
+            ), "Please install xFormers for nested tensors usage"
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError
diff --git a/flash3d/unidepth/models/backbones/metadinov2/dino_head.py b/flash3d/unidepth/models/backbones/metadinov2/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/dino_head.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        use_bn=False,
+        nlayers=3,
+        hidden_dim=2048,
+        bottleneck_dim=256,
+        mlp_bias=True,
+    ):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        self.mlp = _build_mlp(
+            nlayers,
+            in_dim,
+            bottleneck_dim,
+            hidden_dim=hidden_dim,
+            use_bn=use_bn,
+            bias=mlp_bias,
+        )
+        self.apply(self._init_weights)
+        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+        x = self.last_layer(x)
+        return x
+
+
+def _build_mlp(
+    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
+    if nlayers == 1:
+        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+    else:
+        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+        if use_bn:
+            layers.append(nn.BatchNorm1d(hidden_dim))
+        layers.append(nn.GELU())
+        for _ in range(nlayers - 2):
+            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+        return nn.Sequential(*layers)
diff --git a/flash3d/unidepth/models/backbones/metadinov2/drop_path.py b/flash3d/unidepth/models/backbones/metadinov2/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/drop_path.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+import torch.nn as nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
diff --git a/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py b/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c29476cff85a85ab5f071139175f6ac8ba19b2
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/flash3d/unidepth/models/backbones/metadinov2/mlp.py b/flash3d/unidepth/models/backbones/metadinov2/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
diff --git a/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py b/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..837f952cf9a463444feeb146e0d5b539102ee26c
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert (
+            H % patch_H == 0
+        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert (
+            W % patch_W == 0
+        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = (
+            Ho
+            * Wo
+            * self.embed_dim
+            * self.in_chans
+            * (self.patch_size[0] * self.patch_size[1])
+        )
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
diff --git a/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py b/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+try:
+    from xformers.ops import SwiGLU
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    SwiGLU = SwiGLUFFN
+    XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )
diff --git a/flash3d/unidepth/models/encoder.py b/flash3d/unidepth/models/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e302df27bb4bf0f82c1b0baff9c682dc2d2b9e9f
--- /dev/null
+++ b/flash3d/unidepth/models/encoder.py
@@ -0,0 +1,184 @@
+import torch
+import torch.nn as nn
+
+from unidepth.models.backbones import ConvNeXtV2, _make_dinov2_model, ConvNeXt
+
+
+class ModelWrap(nn.Module):
+    def __init__(self, model) -> None:
+        super().__init__()
+        self.backbone = model
+
+    def forward(self, x, *args, **kwargs):
+        features = []
+        for layer in self.backbone.features:
+            x = layer(x)
+            features.append(x)
+        return features
+
+
+def convnextv2_base(config, **kwargs):
+    model = ConvNeXtV2(
+        depths=[3, 3, 27, 3],
+        dims=[128, 256, 512, 1024],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt"
+    state_dict = torch.hub.load_state_dict_from_url(
+        url, map_location="cpu", progress=False
+    )["model"]
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnextv2_large(config, **kwargs):
+    model = ConvNeXtV2(
+        depths=[3, 3, 27, 3],
+        dims=[192, 384, 768, 1536],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt"
+    state_dict = torch.hub.load_state_dict_from_url(
+        url, map_location="cpu", progress=False
+    )["model"]
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnextv2_large_mae(config, **kwargs):
+    model = ConvNeXtV2(
+        depths=[3, 3, 27, 3],
+        dims=[192, 384, 768, 1536],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt"
+    state_dict = torch.hub.load_state_dict_from_url(
+        url, map_location="cpu", progress=False
+    )["model"]
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnextv2_huge(config, **kwargs):
+    model = ConvNeXtV2(
+        depths=[3, 3, 27, 3],
+        dims=[352, 704, 1408, 2816],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt"
+    state_dict = torch.hub.load_state_dict_from_url(
+        url, map_location="cpu", progress=False
+    )["model"]
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnextv2_huge_mae(config, **kwargs):
+    model = ConvNeXtV2(
+        depths=[3, 3, 27, 3],
+        dims=[352, 704, 1408, 2816],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt"
+    state_dict = torch.hub.load_state_dict_from_url(
+        url, map_location="cpu", progress=False
+    )["model"]
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnext_large_pt(config, **kwargs):
+    model = ConvNeXt(
+        depths=[3, 3, 27, 3],
+        dims=[192, 384, 768, 1536],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        **kwargs,
+    )
+    from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn
+    from huggingface_hub import hf_hub_download
+    from huggingface_hub.utils import disable_progress_bars
+
+    disable_progress_bars()
+    repo_id, filename = HF_URL["convnext_large_pt"]
+    state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename))
+    state_dict = checkpoint_filter_fn(state_dict, model)
+    info = model.load_state_dict(state_dict, strict=False)
+    print(info)
+    return model
+
+
+def convnext_large(config, **kwargs):
+    model = ConvNeXt(
+        depths=[3, 3, 27, 3],
+        dims=[192, 384, 768, 1536],
+        output_idx=config.get("output_idx", [3, 6, 33, 36]),
+        use_checkpoint=config.get("use_checkpoint", False),
+        drop_path_rate=config.get("drop_path", 0.0),
+        **kwargs,
+    )
+    return model
+
+
+def dinov2_vitb14(config, pretrained: bool = True, **kwargs):
+    """
+    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    vit = _make_dinov2_model(
+        arch_name="vit_base",
+        pretrained=pretrained,
+        output_idx=config.get("output_idx", [3, 6, 9, 12]),
+        checkpoint=config.get("use_checkpoint", False),
+        drop_path_rate=config.get("drop_path", 0.0),
+        num_register_tokens=config.get("num_register_tokens", 0),
+        **kwargs,
+    )
+    return vit
+
+
+def dinov2_vitl14(config, pretrained: str = "", **kwargs):
+    """
+    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    vit = _make_dinov2_model(
+        arch_name="vit_large",
+        pretrained=config["pretrained"],
+        output_idx=config.get("output_idx", [5, 12, 18, 24]),
+        checkpoint=config.get("use_checkpoint", False),
+        drop_path_rate=config.get("drop_path", 0.0),
+        num_register_tokens=config.get("num_register_tokens", 0),
+        **kwargs,
+    )
+    return vit
+
+
+def dinov2_vitg14(config, pretrained: bool = True, **kwargs):
+    """
+    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    vit = _make_dinov2_model(
+        arch_name="vit_giant2",
+        ffn_layer="swiglufused",
+        pretrained=pretrained,
+        output_idx=config.get("output_idx", [10, 20, 30, 40]),
+        checkpoint=config.get("use_checkpoint", False),
+        drop_path_rate=config.get("drop_path", 0.0),
+        num_register_tokens=config.get("num_register_tokens", 0),
+        **kwargs,
+    )
+    return vit
diff --git a/flash3d/unidepth/models/unidepthv1/__init__.py b/flash3d/unidepth/models/unidepthv1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1781bda94cdc13b0c0c805e7cde0872defc20cd3
--- /dev/null
+++ b/flash3d/unidepth/models/unidepthv1/__init__.py
@@ -0,0 +1,5 @@
+from .unidepthv1 import UniDepthV1
+
+__all__ = [
+    "UniDepthV1",
+]
diff --git a/flash3d/unidepth/models/unidepthv1/decoder.py b/flash3d/unidepth/models/unidepthv1/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1b0fc750ac8798575c401a23994318914cf80f0
--- /dev/null
+++ b/flash3d/unidepth/models/unidepthv1/decoder.py
@@ -0,0 +1,542 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from typing import List, Tuple
+
+from einops import rearrange
+from timm.models.layers import trunc_normal_
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unidepth.layers import (
+    MLP,
+    AttentionBlock,
+    NystromBlock,
+    PositionEmbeddingSine,
+    ConvUpsample,
+)
+from unidepth.utils.sht import rsh_cart_8
+from unidepth.utils.geometric import (
+    generate_rays,
+    flat_interpolate,
+)
+from unidepth.utils.misc import max_stack
+
+
+class ListAdapter(nn.Module):
+    def __init__(self, input_dims: List[int], hidden_dim: int):
+        super().__init__()
+        self.input_adapters = nn.ModuleList([])
+        self.num_chunks = len(input_dims)
+        for input_dim in input_dims:
+            self.input_adapters.append(
+                nn.Sequential(
+                    nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU()
+                )
+            )
+
+    def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor:
+        xs = torch.split(x, splits.int().tolist(), dim=-1)
+        xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)]
+        return torch.cat(xs, dim=-1)
+
+
+class CameraHead(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        num_heads: int = 8,
+        expansion: int = 4,
+        depth: int = 4,
+        dropout: float = 0.0,
+        layer_scale: float = 1.0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.aggregate = AttentionBlock(
+            hidden_dim,
+            num_heads=1,
+            expansion=expansion,
+            dropout=dropout,
+            layer_scale=layer_scale,
+        )
+        self.latents_pos = nn.Parameter(
+            torch.randn(1, 4, hidden_dim), requires_grad=True
+        )
+        self.layers = nn.ModuleList([])
+        self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout)
+        for _ in range(depth):
+            blk = AttentionBlock(
+                hidden_dim,
+                num_heads=num_heads,
+                expansion=expansion,
+                dropout=dropout,
+                layer_scale=layer_scale,
+            )
+            self.layers.append(blk)
+        self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
+        self.cls_project = nn.Sequential(
+            nn.LayerNorm(input_dim),
+            nn.Linear(input_dim, hidden_dim // 2),
+            nn.GELU(),
+            nn.Linear(hidden_dim // 2, hidden_dim),
+        )
+
+    def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
+        features = features.unbind(dim=-1)
+        cls_tokens = self.cls_project(cls_tokens)
+        features_stack = torch.cat(features, dim=1)
+        features_stack = features_stack + pos_embed
+        latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
+        features_stack = self.in_features(features_stack)
+        features = torch.cat((features_stack, cls_tokens), dim=1)
+        cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos)
+        for i, layer in enumerate(self.layers):
+            cls_tokens = layer(cls_tokens, pos_embed=latents_pos)
+
+        # project
+        x = self.out(cls_tokens).squeeze(-1)
+        camera_intrinsics = torch.zeros(
+            x.shape[0], 3, 3, device=x.device, requires_grad=False
+        )
+        camera_intrinsics[:, 0, 0] = x[:, 0].exp()
+        camera_intrinsics[:, 1, 1] = x[:, 1].exp()
+        camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid()
+        camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid()
+        camera_intrinsics[:, 2, 2] = 1.0
+        return camera_intrinsics
+
+    def set_shapes(self, shapes: Tuple[int, int]):
+        self.shapes = shapes
+
+
+class DepthHead(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int = 8,
+        expansion: int = 4,
+        depths: int | list[int] = 4,
+        camera_dim: int = 256,
+        num_resolutions: int = 4,
+        dropout: float = 0.0,
+        layer_scale: float = 1.0,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        if isinstance(depths, int):
+            depths = [depths] * 3
+        assert len(depths) == 3
+
+        self.project_rays16 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
+        )
+        self.project_rays8 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
+        )
+        self.project_rays4 = MLP(
+            camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
+        )
+        self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
+
+        self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
+
+        self.up8 = ConvUpsample(
+            hidden_dim, expansion=expansion, layer_scale=layer_scale
+        )
+        self.up4 = ConvUpsample(
+            hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
+        )
+        self.up2 = ConvUpsample(
+            hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
+        )
+
+        self.layers_16 = nn.ModuleList([])
+        self.layers_8 = nn.ModuleList([])
+        self.layers_4 = nn.ModuleList([])
+        self.aggregate_16 = AttentionBlock(
+            hidden_dim,
+            num_heads=1,
+            expansion=expansion,
+            dropout=dropout,
+            layer_scale=layer_scale,
+            context_dim=hidden_dim,
+        )
+        self.prompt_camera = AttentionBlock(
+            hidden_dim,
+            num_heads=1,
+            expansion=expansion,
+            dropout=dropout,
+            layer_scale=layer_scale,
+            context_dim=hidden_dim,
+        )
+        for i, (blk_lst, depth) in enumerate(
+            zip([self.layers_16, self.layers_8, self.layers_4], depths)
+        ):
+            attn_cls = AttentionBlock if i == 0 else NystromBlock
+            for _ in range(depth):
+                blk_lst.append(
+                    attn_cls(
+                        hidden_dim // (2**i),
+                        num_heads=num_heads // (2**i),
+                        expansion=expansion,
+                        dropout=dropout,
+                        layer_scale=layer_scale,
+                    )
+                )
+
+        self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1)
+        self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1)
+        self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1)
+
+    def set_original_shapes(self, shapes: Tuple[int, int]):
+        self.original_shapes = shapes
+
+    def set_shapes(self, shapes: Tuple[int, int]):
+        self.shapes = shapes
+
+    def forward(
+        self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
+    ) -> torch.Tensor:
+        features = features.unbind(dim=-1)
+        shapes = self.shapes
+
+        # camera_embedding
+        # torch.cuda.synchronize()
+        # start = time()
+        rays_embedding_16 = F.normalize(
+            flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
+        )
+        rays_embedding_8 = F.normalize(
+            flat_interpolate(
+                rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
+            ),
+            dim=-1,
+        )
+        rays_embedding_4 = F.normalize(
+            flat_interpolate(
+                rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
+            ),
+            dim=-1,
+        )
+        rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
+        rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
+        rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
+        # torch.cuda.synchronize()
+        # print(f"camera_embedding took {time() - start} seconds")
+        features_tokens = torch.cat(features, dim=1)
+        features_tokens_pos = pos_embed + level_embed
+
+        # Generate latents with init as pooled features
+        features_channels = torch.cat(features, dim=-1)
+        features_16 = self.features_channel_cat(features_channels)
+        latents_16 = self.to_latents(
+            flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
+        )
+
+        # Aggregate features: F -> D
+        latents_16 = self.aggregate_16(
+            latents_16, context=features_tokens, pos_embed_context=features_tokens_pos
+        )
+
+        # Aggregate camera: D- > D|E
+        latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16)
+
+        # Block 16 - Out 8
+        for layer in self.layers_16:
+            latents_16 = layer(latents_16, pos_embed=rays_embedding_16)
+        latents_8 = self.up8(
+            rearrange(
+                latents_16 + rays_embedding_16,
+                "b (h w) c -> b c h w",
+                h=shapes[0],
+                w=shapes[1],
+            ).contiguous()
+        )
+        out8 = self.out8(
+            rearrange(
+                latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
+            )
+        )
+
+        # Block 8 - Out 4
+        for layer in self.layers_8:
+            latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
+        latents_4 = self.up4(
+            rearrange(
+                latents_8 + rays_embedding_8,
+                "b (h w) c -> b c h w",
+                h=shapes[0] * 2,
+                w=shapes[1] * 2,
+            ).contiguous()
+        )
+        out4 = self.out4(
+            rearrange(
+                latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
+            )
+        )
+
+        # Block 4 - Out 2
+        for layer in self.layers_4:
+            latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
+        latents_2 = self.up2(
+            rearrange(
+                latents_4 + rays_embedding_4,
+                "b (h w) c -> b c h w",
+                h=shapes[0] * 4,
+                w=shapes[1] * 4,
+            ).contiguous()
+        )
+        out2 = self.out2(
+            rearrange(
+                latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
+            )
+        )
+
+        # Depth features
+        proj_latents_16 = rearrange(
+            latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1]
+        ).contiguous()
+
+        # MS Outputs
+        out2 = out2.clamp(-10.0, 10.0).exp()
+        out4 = out4.clamp(-10.0, 10.0).exp()
+        out8 = out8.clamp(-10.0, 10.0).exp()
+
+        return out8, out4, out2, proj_latents_16
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        config,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.build(config)
+        self.apply(self._init_weights)
+        self.test_fixed_camera = False
+        self.skip_camera = False
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_adapted_features(self, features_flat, splits):
+        features_flat_cat = torch.cat(features_flat, dim=-1)
+        features_projected = self.input_adapter(
+            features_flat_cat, splits
+        )  # list [b hw c] shapes
+        features = torch.chunk(features_projected, len(splits), dim=-1)
+        return features
+
+    def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays):
+        # get cls tokens projections
+        cls_tokens_splits = torch.tensor(
+            [x.shape[-1] for x in cls_tokens],
+            device=features.device,
+            requires_grad=False,
+            dtype=features.dtype,
+        )
+        cls_tokens = torch.cat(cls_tokens, dim=-1)
+        cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits)
+        cls_tokens = torch.cat(
+            torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
+        )
+
+        # camera layer
+        intrinsics = self.camera_layer(
+            features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
+        )
+        intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
+        intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
+        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
+        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]
+        if not self.test_fixed_camera:
+            rays, _ = generate_rays(intrinsics, original_shapes, noisy=False)
+
+        return intrinsics, rays
+
+    def forward(self, inputs, image_metas) -> torch.Tensor:
+        B, _, H, W = inputs["image"].shape
+        device = inputs["image"].device
+
+        # make stride happy?
+        original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]]
+        cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]]
+
+        # collect features and tokens
+        original_encoder_outputs = [
+            max_stack(original_encoder_outputs[i:j])
+            for i, j in self.slices_encoder_range
+        ]
+        cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))]
+
+        # get features in b n d format
+        # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions
+        resolutions = [
+            tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs
+        ]
+        level_shapes = sorted(list(set(resolutions)))[::-1]
+
+        if len(level_shapes) == 1:
+            level_shapes = level_shapes * self.num_resolutions
+        input_shapes = [
+            level_shapes[i]
+            for i, (start, end) in enumerate(self.slices_encoder)
+            for _ in range(end - start)
+        ]
+        common_shape = level_shapes[-2]
+
+        # input shapes repeat shapes for each level, times the amount of the layers:
+        features_flat = [
+            flat_interpolate(
+                rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
+            )
+            for x, input_shape in zip(original_encoder_outputs, input_shapes)
+        ]
+        features_splits = torch.tensor(
+            [x.shape[-1] for x in features_flat],
+            device=device,
+            requires_grad=False,
+            dtype=torch.float32,
+        )
+
+        # input adapter, then do mean of features in same blocks
+        features = self.get_adapted_features(features_flat, features_splits)
+        features = torch.stack(features, dim=-1)
+
+        # positional embeddings, spatial and level
+        level_embed = torch.cat(
+            [
+                self.level_embed_layer(self.level_embeds)[i : i + 1]
+                .unsqueeze(0)
+                .repeat(B, common_shape[0] * common_shape[1], 1)
+                for i in range(self.num_resolutions)
+            ],
+            dim=1,
+        )
+        pos_embed = self.pos_embed(
+            torch.zeros(
+                B,
+                1,
+                common_shape[0],
+                common_shape[1],
+                device=device,
+                requires_grad=False,
+            )
+        )
+        pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
+            1, self.num_resolutions, 1
+        )
+
+        self.camera_layer.set_shapes(common_shape)
+        intrinsics, rays = (
+            self.run_camera(
+                cls_tokens,
+                features=features,
+                pos_embed=pos_embed + level_embed,
+                original_shapes=(H, W),
+                rays=inputs.get("rays", None),
+            )
+            if not self.skip_camera
+            else (inputs["K"], inputs["rays"])
+        )
+
+        # run bulk of the model
+        self.depth_layer.set_shapes(common_shape)
+        self.depth_layer.set_original_shapes((H, W))
+        out8, out4, out2, depth_features = self.depth_layer(
+            features=features,
+            rays_hr=rays,
+            pos_embed=pos_embed,
+            level_embed=level_embed,
+        )
+
+        return intrinsics, [out8, out4, out2], depth_features, rays
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {"latents_pos", "level_embeds"}
+
+    def build(self, config):
+        depth = config["model"]["pixel_decoder"]["depths"]
+        input_dims = config["model"]["pixel_encoder"]["embed_dims"]
+        hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
+        num_heads = config["model"]["num_heads"]
+        expansion = config["model"]["expansion"]
+        dropout = config["model"]["pixel_decoder"]["dropout"]
+        depths_encoder = config["model"]["pixel_encoder"]["depths"]
+        num_steps = config["model"].get("num_steps", 100000)
+        layer_scale = 1.0
+        
+        self.depth = depth
+        self.dim = hidden_dim
+        self.downsample = 4
+        self.num_heads = num_heads
+        self.num_resolutions = len(depths_encoder)
+        self.depths_encoder = depths_encoder
+
+        self.slices_encoder_single = list(
+            zip([d - 1 for d in self.depths_encoder], self.depths_encoder)
+        )
+        self.slices_encoder_range = list(
+            zip([0, *self.depths_encoder[:-1]], self.depths_encoder)
+        )
+        cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))]
+
+        input_dims = [input_dims[d - 1] for d in depths_encoder]
+        self.slices_encoder = self.slices_encoder_single
+
+        # adapt from encoder features, just project
+        self.input_adapter = ListAdapter(input_dims, hidden_dim)
+        self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
+
+        # camera layer
+        self.camera_layer = CameraHead(
+            input_dim=hidden_dim,
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            expansion=expansion,
+            depth=2,
+            dropout=dropout,
+            layer_scale=layer_scale,
+        )
+
+        self.depth_layer = DepthHead(
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            expansion=expansion,
+            depths=depth,
+            dropout=dropout,
+            camera_dim=81,
+            num_resolutions=self.num_resolutions,
+            layer_scale=layer_scale,
+        )
+
+        # transformer part
+        self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
+        self.level_embeds = nn.Parameter(
+            torch.randn(len(input_dims), hidden_dim), requires_grad=True
+        )
+        self.level_embed_layer = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+        )
\ No newline at end of file
diff --git a/flash3d/unidepth/models/unidepthv1/unidepthv1.py b/flash3d/unidepth/models/unidepthv1/unidepthv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf0207120b85f033b6479d4db7003eee1563c868
--- /dev/null
+++ b/flash3d/unidepth/models/unidepthv1/unidepthv1.py
@@ -0,0 +1,329 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from copy import deepcopy
+import importlib
+from typing import Any, Dict, Tuple
+from math import ceil
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from einops import rearrange
+
+from unidepth.utils.geometric import (
+    generate_rays,
+    spherical_zbuffer_to_euclidean,
+)
+from unidepth.utils.misc import get_params
+from unidepth.utils.distributed import is_main_process
+from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
+from unidepth.models.unidepthv1.decoder import Decoder
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"}
+
+
+# inference helpers
+def _paddings(image_shape, network_shape):
+    cur_h, cur_w = image_shape
+    h, w = network_shape
+    pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
+    pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
+    return pad_left, pad_right, pad_top, pad_bottom
+
+
+def _shapes(image_shape, network_shape):
+    h, w = image_shape
+    input_ratio = w / h
+    output_ratio = network_shape[1] / network_shape[0]
+    if output_ratio > input_ratio:
+        ratio = network_shape[0] / h
+    elif output_ratio <= input_ratio:
+        ratio = network_shape[1] / w
+    return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
+
+
+def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
+    (pad_left, pad_right, pad_top, pad_bottom) = pads
+    rgbs = F.interpolate(
+        rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
+    )
+    rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
+    if intrinsics is not None:
+        intrinsics = intrinsics.clone()
+        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
+        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
+        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
+        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
+        return rgbs, intrinsics
+    return rgbs, None
+
+
+def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
+    (pad_left, pad_right, pad_top, pad_bottom) = pads
+    # pred mean, trim paddings, and upsample to input dim
+    predictions = sum(
+        [
+            F.interpolate(
+                x.clone(),
+                size=shapes,
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+            for x in predictions
+        ]
+    ) / len(predictions)
+    predictions = predictions[
+        ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
+    ]
+    predictions = F.interpolate(
+        predictions,
+        size=original_shapes,
+        mode="bilinear",
+        align_corners=False,
+        antialias=True,
+    )
+    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
+    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
+    intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
+    intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
+    return predictions, intrinsics
+
+
+class UniDepthV1(nn.Module,
+                 PyTorchModelHubMixin,
+                 library_name="UniDepth",
+                 repo_url="https://github.com/lpiccinelli-eth/UniDepth",
+                 tags=["monocular-metric-depth-estimation"]):
+    def __init__(
+        self,
+        config,
+        eps: float = 1e-6,
+        **kwargs,
+    ):
+        super().__init__()
+        self.build(config)
+        self.eps = eps
+
+    def forward(self, inputs, image_metas):
+        rgbs = inputs["image"]
+        gt_intrinsics = inputs.get("K")
+        H, W = rgbs.shape[-2:]
+
+        # Encode
+        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
+        if "dino" in self.pixel_encoder.__class__.__name__.lower():
+            encoder_outputs = [
+                (x + y.unsqueeze(1)).contiguous()
+                for x, y in zip(encoder_outputs, cls_tokens)
+            ]
+        inputs["encoder_outputs"] = encoder_outputs
+        inputs["cls_tokens"] = cls_tokens
+
+        # Get camera infos, if any
+        if gt_intrinsics is not None:
+            rays, angles = generate_rays(
+                gt_intrinsics, self.image_shape, noisy=self.training
+            )
+            inputs["rays"] = rays
+            inputs["angles"] = angles
+            inputs["K"] = gt_intrinsics
+            self.pixel_decoder.test_fixed_camera = True  # use GT camera in fwd
+
+        # Decode
+        pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
+        predictions = sum(
+            [
+                F.interpolate(
+                    x.clone(),
+                    size=self.image_shape,
+                    mode="bilinear",
+                    align_corners=False,
+                    antialias=True,
+                )
+                for x in predictions
+            ]
+        ) / len(predictions)
+
+        # Final 3D points backprojection
+        pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
+        # You may want to use inputs["angles"] if available?
+        pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
+        points_3d = torch.cat((pred_angles, predictions), dim=1)
+        points_3d = spherical_zbuffer_to_euclidean(
+            points_3d.permute(0, 2, 3, 1)
+        ).permute(0, 3, 1, 2)
+
+        # Output data, use for loss computation
+        outputs = {
+            "angles": pred_angles,
+            "intrinsics": pred_intrinsics,
+            "points": points_3d,
+            "depth": predictions[:, -1:],
+        }
+        self.pixel_decoder.test_fixed_camera = False
+        return outputs
+
+    @torch.no_grad()
+    def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
+        if rgbs.ndim == 3:
+            rgbs = rgbs.unsqueeze(0)
+        if intrinsics is not None and intrinsics.ndim == 2:
+            intrinsics = intrinsics.unsqueeze(0)
+        B, _, H, W = rgbs.shape
+
+        rgbs = rgbs.to(self.device)
+        if intrinsics is not None:
+            intrinsics = intrinsics.to(self.device)
+
+        # process image and intrinsiscs (if any) to match network input (slow?)
+        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
+            rgbs = TF.normalize(
+                rgbs.to(torch.float32).div(255),
+                mean=IMAGENET_DATASET_MEAN,
+                std=IMAGENET_DATASET_STD,
+            )
+        else:
+            pass
+            # print("Image not normalized, was it already normalized?")
+        (h, w), ratio = _shapes((H, W), self.image_shape)
+        pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
+        rgbs, gt_intrinsics = _preprocess(
+            rgbs,
+            intrinsics,
+            (h, w),
+            (pad_left, pad_right, pad_top, pad_bottom),
+            ratio,
+            self.image_shape,
+        )
+
+        # run encoder
+        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
+        if "dino" in self.pixel_encoder.__class__.__name__.lower():
+            encoder_outputs = [
+                (x + y.unsqueeze(1)).contiguous()
+                for x, y in zip(encoder_outputs, cls_tokens)
+            ]
+
+        # get data for decoder and adapt to given camera
+        inputs = {}
+        inputs["encoder_outputs"] = encoder_outputs
+        inputs["cls_tokens"] = cls_tokens
+        inputs["image"] = rgbs
+        if gt_intrinsics is not None:
+            rays, angles = generate_rays(
+                gt_intrinsics, self.image_shape, noisy=self.training
+            )
+            inputs["rays"] = rays
+            inputs["angles"] = angles
+            inputs["K"] = gt_intrinsics
+            self.pixel_decoder.test_fixed_camera = True
+            self.pixel_decoder.skip_camera = skip_camera
+
+        # decode all
+        pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
+
+        # undo the reshaping and get original image size (slow)
+        predictions, pred_intrinsics = _postprocess(
+            predictions,
+            pred_intrinsics,
+            self.image_shape,
+            (pad_left, pad_right, pad_top, pad_bottom),
+            ratio,
+            (H, W),
+        )
+
+        # final 3D points backprojection
+        intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
+        angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
+        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
+        points_3d = torch.cat((angles, predictions), dim=1)
+        points_3d = spherical_zbuffer_to_euclidean(
+            points_3d.permute(0, 2, 3, 1)
+        ).permute(0, 3, 1, 2)
+
+        # output data
+        outputs = {
+            "intrinsics": pred_intrinsics,
+            "points": points_3d,
+            "depth": predictions[:, -1:],
+        }
+        self.pixel_decoder.test_fixed_camera = False
+        self.pixel_decoder.skip_camera = False
+        return outputs
+
+    def load_pretrained(self, model_file):
+        device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+        dict_model = torch.load(model_file, map_location=device)
+
+        if "model" in dict_model:
+            dict_model = dict_model["model"]
+        new_state_dict = deepcopy(
+            {k.replace("module.", ""): v for k, v in dict_model.items()}
+        )
+
+        info = self.load_state_dict(new_state_dict, strict=False)
+        if is_main_process():
+            print(
+                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
+                info,
+            )
+
+    def get_params(self, config):
+        if hasattr(self.pixel_encoder, "get_params"):
+            encoder_p, encoder_lr = self.pixel_encoder.get_params(
+                config["model"]["pixel_encoder"]["lr"],
+                config["training"]["wd"],
+                config["training"]["ld"],
+            )
+        else:
+            encoder_p, encoder_lr = get_params(
+                self.pixel_encoder,
+                config["model"]["pixel_encoder"]["lr"],
+                config["training"]["wd"],
+            )
+        decoder_p, decoder_lr = get_params(
+            self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
+        )
+        return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def build(self, config: Dict[str, Dict[str, Any]]):
+        mod = importlib.import_module("unidepth.models.encoder")
+        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
+        pixel_encoder_config = {
+            **config["training"],
+            **config["data"],
+            **config["model"]["pixel_encoder"],
+        }
+        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
+
+        config["model"]["pixel_encoder"]["patch_size"] = (
+            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
+        )
+        pixel_encoder_embed_dims = (
+            pixel_encoder.embed_dims
+            if hasattr(pixel_encoder, "embed_dims")
+            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
+        )
+        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
+            pixel_encoder, "embed_dim"
+        )
+        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
+        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
+
+        self.pixel_encoder = pixel_encoder
+        self.pixel_decoder = Decoder(config)
+        self.image_shape = config["data"]["image_shape"]
diff --git a/flash3d/unidepth/ops/__init__.py b/flash3d/unidepth/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..412c242f25ec2e4c496cde55bbbf0b47a8580081
--- /dev/null
+++ b/flash3d/unidepth/ops/__init__.py
@@ -0,0 +1,9 @@
+from .losses import SILog, MSE, SelfCons
+from .scheduler import CosineScheduler
+
+__all__ = [
+    "SILog",
+    "MSE",
+    "SelfCons",
+    "CosineScheduler",
+]
diff --git a/flash3d/unidepth/ops/losses.py b/flash3d/unidepth/ops/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dda902708aca5b15e680d1dfd88bd040e68ef6f
--- /dev/null
+++ b/flash3d/unidepth/ops/losses.py
@@ -0,0 +1,429 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from typing import Any, Optional, Dict, Tuple, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+FNS = {
+    "sqrt": torch.sqrt,
+    "log": torch.log,
+    "log1": lambda x: torch.log(x + 1),
+    "linear": lambda x: x,
+    "square": torch.square,
+    "disp": lambda x: 1 / x,
+}
+
+
+FNS_INV = {
+    "sqrt": torch.square,
+    "log": torch.exp,
+    "log1": lambda x: torch.exp(x) - 1,
+    "linear": lambda x: x,
+    "square": torch.sqrt,
+    "disp": lambda x: 1 / x,
+}
+
+
+def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+    if mask is None:
+        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+        mask_sum, min=1.0
+    )
+    mask_var = torch.sum(
+        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+    ) / torch.clamp(mask_sum, min=1.0)
+    return mask_mean.squeeze(dim), mask_var.squeeze(dim)
+
+
+def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
+    if mask is None:
+        return data.mean(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+        mask_sum, min=1.0
+    )
+    return mask_mean
+
+
+def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
+    if mask is None:
+        return data.abs().mean(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+    mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp(
+        mask_sum, min=1.0
+    )
+    return mask_mean
+
+
+def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
+    if mask is None:
+        return (data**2).mean(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+    mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp(
+        mask_sum, min=1.0
+    )
+    return mask_mean
+
+
+def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+    ndim = data.ndim
+    data = data.flatten(ndim - len(dim))
+    mask = mask.flatten(ndim - len(dim))
+    mask_median = torch.median(data[mask], dim=-1).values
+    return mask_median
+
+
+def masked_median_mad(data: torch.Tensor, mask: torch.Tensor):
+    data = data.flatten()
+    mask = mask.flatten()
+    mask_median = torch.median(data[mask])
+    n_samples = torch.clamp(torch.sum(mask.float()), min=1.0)
+    mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples
+    return mask_median, mask_mad
+
+
+def masked_weighted_mean_var(
+    data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
+):
+    if mask is None:
+        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
+        mask * weights, dim=dim, keepdim=True
+    ).clamp(min=1.0)
+    # V1**2 - V2, V1: sum w_i, V2: sum w_i**2
+    denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
+        (mask * weights).square(), dim=dim, keepdim=True
+    )
+    # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd)
+    correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
+        min=1.0
+    )
+    mask_var = correction_factor * torch.sum(
+        weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+    )
+    return mask_mean, mask_var
+
+
+def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+    if mask is None:
+        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+    mask = mask.float()
+    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+        mask_sum, min=1.0
+    )
+    mask_var = torch.sum(
+        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+    ) / torch.clamp(mask_sum, min=1.0)
+    return mask_mean, mask_var
+
+
+class SILog(nn.Module):
+    def __init__(
+        self,
+        weight: float,
+        scale_pred_weight: float = 0.15,
+        output_fn: str = "sqrt",
+        input_fn: str = "log",
+        legacy: bool = False,
+        abs_rel: bool = False,
+        norm: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        assert output_fn in FNS
+        self.name: str = self.__class__.__name__
+        self.weight: float = weight
+
+        self.scale_pred_weight: float = scale_pred_weight
+        self.dims = (-4, -3, -2, -1) if legacy else (-2, -1)
+        self.output_fn = FNS[output_fn]
+        self.input_fn = FNS[input_fn]
+        self.abs_rel = abs_rel
+        self.norm = norm
+        self.eps: float = eps
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        interpolate: bool = True,
+        scale_inv: torch.Tensor | None = None,
+        ss_inv: torch.Tensor | None = None,
+        **kwargs
+    ) -> torch.Tensor:
+        if interpolate:
+            input = F.interpolate(
+                input, target.shape[-2:], mode="bilinear", align_corners=False
+            )
+        if mask is not None:
+            mask = mask.to(torch.bool)
+        if ss_inv is not None:
+            ss_inv = ~ss_inv
+
+        if input.shape[1] > 1:
+            input_ = torch.cat(
+                [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1
+            )
+            target_ = torch.cat(
+                [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))],
+                dim=1,
+            )
+            error = torch.norm(input_ - target_, dim=1, keepdim=True)
+        else:
+            input_ = self.input_fn(input.clamp(min=self.eps))
+            target_ = self.input_fn(target.clamp(min=self.eps))
+            error = input_ - target_
+
+        mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims)
+
+        # prevoiusly was inverted!!
+        if self.abs_rel:
+            scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip(
+                min=self.eps
+            )
+            scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims)
+        else:
+            scale_error = mean_error**2
+
+        if var_error.ndim > 1:
+            var_error = var_error.sum(dim=1)
+            scale_error = scale_error.sum(dim=1)
+
+        # if scale inv -> mask scale error, if scale/shift, mask the full loss
+        if scale_inv is not None:
+            scale_error = (1 - scale_inv.int()) * scale_error
+        scale_error = self.scale_pred_weight * scale_error
+        loss = var_error + scale_error
+        out_loss = self.output_fn(loss.clamp(min=self.eps))
+        out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,))
+        return out_loss.mean()
+
+    @classmethod
+    def build(cls, config: Dict[str, Any]):
+        obj = cls(
+            weight=config["weight"],
+            legacy=config["legacy"],
+            output_fn=config["output_fn"],
+            input_fn=config["input_fn"],
+            norm=config.get("norm", False),
+            scale_pred_weight=config.get("gamma", 0.15),
+            abs_rel=config.get("abs_rel", False),
+        )
+        return obj
+
+
+class MSE(nn.Module):
+    def __init__(
+        self,
+        weight: float = 1.0,
+        input_fn: str = "linear",
+        output_fn: str = "linear",
+    ):
+        super().__init__()
+        self.name: str = self.__class__.__name__
+        self.output_fn = FNS[output_fn]
+        self.input_fn = FNS[input_fn]
+        self.weight: float = weight
+        self.eps = 1e-6
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        batch_mask: torch.Tensor | None = None,
+        **kwargs
+    ) -> torch.Tensor:
+        input = input[..., : target.shape[-1]]  # B N C or B H W C
+        error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps)
+        abs_error = torch.square(error).sum(dim=-1)
+        mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1)
+        batched_error = masked_mean(
+            self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,)
+        )
+        return batched_error.mean(), mean_error.detach()
+
+    @classmethod
+    def build(cls, config: Dict[str, Any]):
+        obj = cls(
+            weight=config["weight"],
+            output_fn=config["output_fn"],
+            input_fn=config["input_fn"],
+        )
+        return obj
+
+
+class SelfCons(nn.Module):
+    def __init__(
+        self,
+        weight: float,
+        scale_pred_weight: float = 0.15,
+        output_fn: str = "sqrt",
+        input_fn: str = "log",
+        abs_rel: bool = False,
+        norm: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        assert output_fn in FNS
+        self.name: str = self.__class__.__name__
+        self.weight: float = weight
+
+        self.scale_pred_weight: float = scale_pred_weight
+        self.dims = (-2, -1)
+        self.output_fn = FNS[output_fn]
+        self.input_fn = FNS[input_fn]
+        self.abs_rel = abs_rel
+        self.norm = norm
+        self.eps: float = eps
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(
+        self,
+        input: torch.Tensor,
+        mask: torch.Tensor,
+        metas: List[Dict[str, torch.Tensor]],
+    ) -> torch.Tensor:
+        chunks = input.shape[0] // 2
+        device = input.device
+        mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest")
+
+        rescales = input.shape[-2] / torch.tensor(
+            [x["resized_shape"][0] for x in metas], device=device
+        )
+        cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device)
+        flips = torch.tensor([x["flip"] for x in metas], device=device)
+
+        iters = zip(
+            input.chunk(chunks),
+            mask.chunk(chunks),
+            cams.chunk(chunks),
+            rescales.chunk(chunks),
+            flips.chunk(chunks),
+        )
+        inputs0, inputs1, masks = [], [], []
+        for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate(
+            iters
+        ):
+            mask0, mask1 = pair_mask
+            input0, input1 = pair_input
+            cam0, cam1 = pair_cam
+            rescale0, rescale1 = pair_rescale
+            flip0, flip1 = pair_flip
+
+            fx_0 = cam0[0, 0] * rescale0
+            fx_1 = cam1[0, 0] * rescale1
+            cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5
+            cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5
+            cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5
+            cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5
+
+            # flip image
+            if flip0 ^ flip1:
+                input0 = torch.flip(input0, dims=(2,))
+                mask0 = torch.flip(mask0, dims=(2,))
+                cx_0 = input0.shape[-1] - cx_0
+
+            # calc zoom
+            zoom_x = float(fx_1 / fx_0)
+
+            # apply zoom
+            input0 = F.interpolate(
+                input0.unsqueeze(0),
+                scale_factor=zoom_x,
+                mode="bilinear",
+                align_corners=True,
+            ).squeeze(0)
+            mask0 = F.interpolate(
+                mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest"
+            ).squeeze(0)
+
+            # calc translation
+            change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5)
+            change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5)
+            change_right = input1.shape[-1] - change_left - input0.shape[-1]
+            change_bottom = input1.shape[-2] - change_top - input0.shape[-2]
+
+            # apply translation
+            pad_left = max(0, change_left)
+            pad_right = max(0, change_right)
+            pad_top = max(0, change_top)
+            pad_bottom = max(0, change_bottom)
+
+            crop_left = max(0, -change_left)
+            crop_right = max(0, -change_right)
+            crop_top = max(0, -change_top)
+            crop_bottom = max(0, -change_bottom)
+
+            input0 = F.pad(
+                input0,
+                (pad_left, pad_right, pad_top, pad_bottom),
+                mode="constant",
+                value=0,
+            )
+            mask0 = F.pad(
+                mask0,
+                (pad_left, pad_right, pad_top, pad_bottom),
+                mode="constant",
+                value=0,
+            )
+            input0 = input0[
+                :,
+                crop_top : input0.shape[-2] - crop_bottom,
+                crop_left : input0.shape[-1] - crop_right,
+            ]
+            mask0 = mask0[
+                :,
+                crop_top : mask0.shape[-2] - crop_bottom,
+                crop_left : mask0.shape[-1] - crop_right,
+            ]
+
+            mask = torch.logical_and(mask0, mask1)
+
+            inputs0.append(input0)
+            inputs1.append(input1)
+            masks.append(mask)
+
+        inputs0 = torch.stack(inputs0, dim=0)
+        inputs1 = torch.stack(inputs1, dim=0)
+        masks = torch.stack(masks, dim=0)
+        loss1 = self.loss(inputs0, inputs1.detach(), masks)
+        loss2 = self.loss(inputs1, inputs0.detach(), masks)
+        return torch.cat([loss1, loss2], dim=0).mean()
+
+    def loss(
+        self,
+        input: torch.Tensor,
+        target: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        loss = masked_mean(
+            (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1)
+        )
+        return self.output_fn(loss + self.eps)
+
+    @classmethod
+    def build(cls, config: Dict[str, Any]):
+        obj = cls(
+            weight=config["weight"],
+            output_fn=config["output_fn"],
+            input_fn=config["input_fn"],
+        )
+        return obj
diff --git a/flash3d/unidepth/ops/scheduler.py b/flash3d/unidepth/ops/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a182ff6e204ab445a67846314a8bea087119685e
--- /dev/null
+++ b/flash3d/unidepth/ops/scheduler.py
@@ -0,0 +1,70 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import numpy as np
+
+
+class CosineScheduler(object):
+    def __init__(
+        self,
+        optimizer,
+        warmup_iters,
+        total_iters,
+        key,
+        overwrite=False,
+        init_value=None,
+        base_value=None,
+        final_value=None,
+        step_init=-1,
+    ):
+        super().__init__()
+        self.iter = step_init
+        self.overwrite = overwrite
+        self.optimizer = optimizer
+        self.base_value = base_value
+        self.init_value = init_value
+        self.final_value = final_value
+        self.total_iters = total_iters
+        self.warmup_iters = warmup_iters
+        self.key = key
+        self.schedulers = [
+            self.get_schedulers(group) for group in optimizer.param_groups
+        ]
+
+    def get_schedulers(self, group):
+        init_value = group.get(self.key + "_init", self.init_value)
+        base_value = group.get(self.key + "_base", self.base_value)
+        final_value = group.get(self.key + "_final", self.final_value)
+        warmup_iters = self.warmup_iters
+        total_iters = self.total_iters
+        if self.overwrite:
+            final_value = self.final_value
+
+        # normalize in 0,1, then apply function (power) and denormalize
+        normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
+        normalized_schedule = np.power(normalized_schedule, 2)
+        warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
+
+        # main scheduling
+        iters = np.arange(total_iters - warmup_iters)
+        schedule = final_value + 0.5 * (base_value - final_value) * (
+            1 + np.cos(np.pi * iters / len(iters))
+        )
+        return np.concatenate((warmup_schedule, schedule))
+
+    def step(self):
+        self.iter = self.iter + 1
+        vals = self[self.iter]
+        for group, val in zip(self.optimizer.param_groups, vals):
+            if isinstance(group[self.key], (tuple, list)):
+                val = (val, *group[self.key][1:])
+            group[self.key] = val
+
+    def __getitem__(self, it):
+        it = min(it, self.total_iters - 1)
+        return [scheduler[it] for scheduler in self.schedulers]
+
+    def get(self):
+        return [group[self.key] for group in self.optimizer.param_groups]
diff --git a/flash3d/unidepth/utils/__init__.py b/flash3d/unidepth/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e3153806457cace0aa1ffd2820e60f4d87c8180
--- /dev/null
+++ b/flash3d/unidepth/utils/__init__.py
@@ -0,0 +1,35 @@
+from .evaluation_depth import eval_depth, DICT_METRICS
+from .visualization import colorize, image_grid, log_train_artifacts
+from .misc import format_seconds, remove_padding, get_params, identity
+from .distributed import (
+    is_main_process,
+    setup_multi_processes,
+    setup_slurm,
+    sync_tensor_across_gpus,
+    barrier,
+    get_rank,
+    get_dist_info,
+)
+from .geometric import unproject_points, spherical_zbuffer_to_euclidean
+
+__all__ = [
+    "eval_depth",
+    "DICT_METRICS",
+    "colorize",
+    "image_grid",
+    "log_train_artifacts",
+    "format_seconds",
+    "remove_padding",
+    "get_params",
+    "identity",
+    "is_main_process",
+    "setup_multi_processes",
+    "setup_slurm",
+    "sync_tensor_across_gpus",
+    "barrier",
+    "get_rank",
+    "unproject_points",
+    "spherical_zbuffer_to_euclidean",
+    "validate",
+    "get_dist_info",
+]
diff --git a/flash3d/unidepth/utils/constants.py b/flash3d/unidepth/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b23481335056a8bfe756bf6d0772bbf10c2ca22
--- /dev/null
+++ b/flash3d/unidepth/utils/constants.py
@@ -0,0 +1,21 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import math
+import torch
+
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
+DEPTH_BINS = torch.cat(
+    (
+        torch.logspace(math.log10(0.1), math.log10(180.0), steps=512),
+        torch.tensor([260.0]),
+    ),
+    dim=0,
+)
+LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1)
+LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1)
diff --git a/flash3d/unidepth/utils/distributed.py b/flash3d/unidepth/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd8a501582808fe967f54f51fb645b667137d02
--- /dev/null
+++ b/flash3d/unidepth/utils/distributed.py
@@ -0,0 +1,179 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import os
+import platform
+import warnings
+import subprocess
+
+import cv2
+
+import torch
+import torch.utils.data.distributed
+from torch import multiprocessing as mp
+from torch import distributed as dist
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def barrier():
+    if not is_dist_avail_and_initialized():
+        return
+    dist.barrier()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def is_rank_zero(args):
+    return args.rank == 0
+
+
+def get_dist_info():
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def setup_multi_processes(cfg):
+    """Setup multi-processing environment variables."""
+    # set multi-process start method as `fork` to speed up the training
+    if platform.system() != "Windows":
+        mp_start_method = cfg.get("mp_start_method", "fork")
+        current_method = mp.get_start_method(allow_none=True)
+        if current_method is not None and current_method != mp_start_method:
+            warnings.warn(
+                f"Multi-processing start method `{mp_start_method}` is "
+                f"different from the previous setting `{current_method}`."
+                f"It will be force set to `{mp_start_method}`. You can change "
+                f"this behavior by changing `mp_start_method` in your config."
+            )
+        mp.set_start_method(mp_start_method, force=True)
+
+    # disable opencv multithreading to avoid system being overloaded
+    opencv_num_threads = cfg.get("opencv_num_threads", 0)
+    cv2.setNumThreads(opencv_num_threads)
+
+    # setup OMP threads
+    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
+    workers_per_gpu = cfg.get("workers_per_gpu", 4)
+
+    if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
+        omp_num_threads = 1
+        warnings.warn(
+            f"Setting OMP_NUM_THREADS environment variable for each process "
+            f"to be {omp_num_threads} in default, to avoid your system being "
+            f"overloaded, please further tune the variable for optimal "
+            f"performance in your application as needed."
+        )
+        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
+
+    # setup MKL threads
+    if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
+        mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1)
+        warnings.warn(
+            f"Setting MKL_NUM_THREADS environment variable for each process "
+            f"to be {mkl_num_threads} in default, to avoid your system being "
+            f"overloaded, please further tune the variable for optimal "
+            f"performance in your application as needed."
+        )
+        os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
+
+
+def setup_slurm(backend: str, port: str) -> None:
+    """Initialize slurm distributed training environment.
+    If argument ``port`` is not specified, then the master port will be system
+    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+    environment variable, then a default port ``29500`` will be used.
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    proc_id = int(os.environ["SLURM_PROCID"])
+    ntasks = int(os.environ["SLURM_NTASKS"])
+    node_list = os.environ["SLURM_NODELIST"]
+
+    num_gpus = torch.cuda.device_count()
+
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["MASTER_ADDR"] = addr
+    os.environ["WORLD_SIZE"] = str(ntasks)
+    os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
+    os.environ["RANK"] = str(proc_id)
+    print(
+        proc_id,
+        ntasks,
+        num_gpus,
+        proc_id % num_gpus,
+        node_list,
+        addr,
+        os.environ["MASTER_PORT"],
+        os.system("nvidia-smi -L"),
+    )
+    dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
+
+
+def sync_tensor_across_gpus(t, dim=0, cat=True):
+    if t is None or not (dist.is_available() and dist.is_initialized()):
+        return t
+    t = torch.atleast_1d(t)
+    group = dist.group.WORLD
+    group_size = torch.distributed.get_world_size(group)
+
+    local_size = torch.tensor(t.size(dim), device=t.device)
+    all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
+    dist.all_gather(all_sizes, local_size)
+    max_size = max(all_sizes)
+    size_diff = max_size.item() - local_size.item()
+    if size_diff:
+        padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
+        t = torch.cat((t, padding))
+
+    gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
+    dist.all_gather(gather_t_tensor, t)
+    all_ts = []
+    for t, size in zip(gather_t_tensor, all_sizes):
+        all_ts.append(t[:size])
+    if cat:
+        return torch.cat(all_ts, dim=0)
+    return all_ts
+
+
+import pickle
+
+
+def sync_string_across_gpus(keys: list[str], device, dim=0):
+    keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
+    keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to(
+        device
+    )
+    keys_serialized_tensor = sync_tensor_across_gpus(
+        keys_serialized_tensor, dim=0, cat=False
+    )
+    keys = [
+        key
+        for keys in keys_serialized_tensor
+        for key in pickle.loads(bytes(keys.cpu().tolist()))
+    ]
+    return keys
diff --git a/flash3d/unidepth/utils/ema_torch.py b/flash3d/unidepth/utils/ema_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a37461f017fb220c88c82416eb3d2c371b5a9e8
--- /dev/null
+++ b/flash3d/unidepth/utils/ema_torch.py
@@ -0,0 +1,342 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from __future__ import division
+from __future__ import unicode_literals
+
+from typing import Iterable, Optional
+import weakref
+import copy
+import contextlib
+from math import tanh
+
+import torch
+
+
+class DummyExponentialMovingAverage:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def _get_parameters(self, *args, **kwargs):
+        pass
+
+    def get_current_decay(self, *args, **kwargs):
+        pass
+
+    def update(self, *args, **kwargs):
+        pass
+
+    def copy_to(self, *args, **kwargs):
+        pass
+
+    def store(self, *args, **kwargs):
+        return
+
+    def restore(self, *args, **kwargs):
+        return
+
+    @contextlib.contextmanager
+    def average_parameters(self, *args, **kwargs):
+        try:
+            yield
+        finally:
+            pass
+
+    def to(self, *args, **kwargs):
+        pass
+
+    def state_dict(self, *args, **kwargs):
+        pass
+
+    def load_state_dict(self, *args, **kwargs):
+        pass
+
+
+class ExponentialMovingAverage:
+    """
+    Maintains (exponential) moving average of a set of parameters.
+
+    Args:
+        parameters: Iterable of `torch.nn.Parameter` (typically from
+            `model.parameters()`).
+            Note that EMA is computed on *all* provided parameters,
+            regardless of whether or not they have `requires_grad = True`;
+            this allows a single EMA object to be consistantly used even
+            if which parameters are trainable changes step to step.
+
+            If you want to some parameters in the EMA, do not pass them
+            to the object in the first place. For example:
+
+                ExponentialMovingAverage(
+                    parameters=[p for p in model.parameters() if p.requires_grad],
+                    decay=0.9
+                )
+
+            will ignore parameters that do not require grad.
+
+        decay: The exponential decay.
+
+        use_num_updates: Whether to use number of updates when computing
+            averages.
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        decay: float,
+        use_num_updates: bool = True,
+        update_after_step: int = 10000,
+        tau: int = 20000,
+        switch: bool = False,
+    ):
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+        self.decay = decay
+        self.switch = switch  # fi keeping EMA params in model after epochs
+        self.num_updates = 0 if use_num_updates else None
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+        self.collected_params = None
+        # By maintaining only a weakref to each parameter,
+        # we maintain the old GC behaviour of ExponentialMovingAverage:
+        # if the model goes out of scope but the ExponentialMovingAverage
+        # is kept, no references to the model or its parameters will be
+        # maintained, and the model will be cleaned up.
+        self._params_refs = [weakref.ref(p) for p in parameters]
+        self.update_after_step = update_after_step
+        self.tau = tau
+
+    def _get_parameters(
+        self, parameters: Optional[Iterable[torch.nn.Parameter]]
+    ) -> Iterable[torch.nn.Parameter]:
+        if parameters is None:
+            parameters = [p() for p in self._params_refs]
+            if any(p is None for p in parameters):
+                raise ValueError(
+                    "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);"
+                    " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected."
+                )
+            return parameters
+        else:
+            parameters = list(parameters)
+            if len(parameters) != len(self.shadow_params):
+                raise ValueError(
+                    "Number of parameters passed as argument is different "
+                    "from number of shadow parameters maintained by this "
+                    "ExponentialMovingAverage"
+                )
+            return parameters
+
+    def get_current_decay(self):
+        epoch = max(self.num_updates - self.update_after_step - 1, 0.0)
+        if epoch <= 0:
+            return 0.0
+        value = tanh(epoch / self.tau) * self.decay
+        return value
+
+    def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
+        """
+        Update currently maintained parameters.
+
+        Call this every time the parameters are updated, such as the result of
+        the `optimizer.step()` call.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                parameters used to initialize this object. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        decay = self.get_current_decay()
+        if self.num_updates is not None:
+            self.num_updates += 1
+
+        one_minus_decay = 1.0 - decay
+        with torch.no_grad():
+            for s_param, param in zip(self.shadow_params, parameters):
+                tmp = s_param - param
+                # tmp will be a new tensor so we can do in-place
+                tmp.mul_(one_minus_decay)
+                s_param.sub_(tmp)
+
+    def copy_to(
+        self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
+        """
+        Save the current parameters for restoring later.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters of with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        self.collected_params = [param.detach().clone() for param in parameters]
+
+    def restore(
+        self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        if self.collected_params is None:
+            raise RuntimeError(
+                "This ExponentialMovingAverage has no `store()`ed weights "
+                "to `restore()`"
+            )
+        parameters = self._get_parameters(parameters)
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+    @contextlib.contextmanager
+    def average_parameters(
+        self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ):
+        r"""
+        Context manager for validation/inference with averaged parameters.
+
+        Equivalent to:
+
+            ema.store()
+            ema.copy_to()
+            try:
+                ...
+            finally:
+                ema.restore()
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        self.store(parameters)
+        self.copy_to(parameters)
+        try:
+            yield
+        finally:
+            if not self.switch:
+                self.restore(parameters)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            (
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+            )
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                (
+                    p.to(device=device, dtype=dtype)
+                    if p.is_floating_point()
+                    else p.to(device=device)
+                )
+                for p in self.collected_params
+            ]
+        return
+
+    def state_dict(self) -> dict:
+        r"""Returns the state of the ExponentialMovingAverage as a dict."""
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "num_updates": self.num_updates,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""Loads the ExponentialMovingAverage state.
+
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+        self.num_updates = state_dict["num_updates"]
+        assert self.num_updates is None or isinstance(
+            self.num_updates, int
+        ), "Invalid num_updates"
+
+        self.shadow_params = state_dict["shadow_params"]
+        assert isinstance(self.shadow_params, list), "shadow_params must be a list"
+        assert all(
+            isinstance(p, torch.Tensor) for p in self.shadow_params
+        ), "shadow_params must all be Tensors"
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            assert isinstance(
+                self.collected_params, list
+            ), "collected_params must be a list"
+            assert all(
+                isinstance(p, torch.Tensor) for p in self.collected_params
+            ), "collected_params must all be Tensors"
+            assert len(self.collected_params) == len(
+                self.shadow_params
+            ), "collected_params and shadow_params had different lengths"
+
+        if len(self.shadow_params) == len(self._params_refs):
+            # Consistant with torch.optim.Optimizer, cast things to consistant
+            # device and dtype with the parameters
+            params = [p() for p in self._params_refs]
+            # If parameters have been garbage collected, just load the state
+            # we were given without change.
+            if not any(p is None for p in params):
+                # ^ parameter references are still good
+                for i, p in enumerate(params):
+                    self.shadow_params[i] = self.shadow_params[i].to(
+                        device=p.device, dtype=p.dtype
+                    )
+                    if self.collected_params is not None:
+                        self.collected_params[i] = self.collected_params[i].to(
+                            device=p.device, dtype=p.dtype
+                        )
+        else:
+            raise ValueError(
+                "Tried to `load_state_dict()` with the wrong number of "
+                "parameters in the saved state."
+            )
diff --git a/flash3d/unidepth/utils/evaluation_depth.py b/flash3d/unidepth/utils/evaluation_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f84ca0591efae3f13f78ab4e10b5069ae2f74eb
--- /dev/null
+++ b/flash3d/unidepth/utils/evaluation_depth.py
@@ -0,0 +1,173 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+# We prefer not to install PyTorch3D in the package
+# Code commented is how 3D metrics are computed
+
+from collections import defaultdict
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+
+# from chamfer_distance import ChamferDistance
+
+from unidepth.utils.constants import DEPTH_BINS
+
+
+# chamfer_cls = ChamferDistance()
+
+
+# def chamfer_dist(tensor1, tensor2):
+#     x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+#     y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+#     dist1, dist2, idx1, idx2 = chamfer_cls(
+#         tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+#     )
+#     return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
+
+
+# def auc(tensor1, tensor2, thresholds):
+#     x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+#     y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+#     dist1, dist2, idx1, idx2 = chamfer_cls(
+#         tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+#     )
+#     # compute precision recall
+#     precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
+#     recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
+#     auc_value = torch.trapz(
+#         torch.tensor(precisions, device=tensor1.device),
+#         torch.tensor(recalls, device=tensor1.device),
+#     )
+#     return auc_value
+
+
+def delta(tensor1, tensor2, exponent):
+    inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
+    return (inlier < 1.25**exponent).to(torch.float32).mean()
+
+
+def ssi(tensor1, tensor2, qtl=0.05):
+    stability_mat = 1e-9 * torch.eye(2, device=tensor1.device)
+    error = (tensor1 - tensor2).abs()
+    mask = error < torch.quantile(error, 1 - qtl)
+    tensor1_mask = tensor1[mask]
+    tensor2_mask = tensor2[mask]
+    tensor2_one = torch.stack(
+        [tensor2_mask.detach(), torch.ones_like(tensor2_mask).detach()], dim=1
+    )
+    scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
+        tensor2_one.T @ tensor1_mask.unsqueeze(1)
+    )
+    scale, shift = scale_shift.squeeze().chunk(2, dim=0)
+    return tensor2 * scale + shift
+    # tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1)
+    # scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1))
+    # scale, shift = scale_shift.squeeze().chunk(2, dim=0)
+    # return tensor2 * scale + shift
+
+
+def d1_ssi(tensor1, tensor2):
+    delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0)
+    return delta_
+
+
+def d_auc(tensor1, tensor2):
+    exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
+    deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents]
+    return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0
+
+
+# def f1_score(tensor1, tensor2, thresholds):
+#     x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+#     y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+#     dist1, dist2, idx1, idx2 = chamfer_cls(
+#         tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+#     )
+#     # compute precision recall
+#     precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
+#     recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
+#     precisions = torch.tensor(precisions, device=tensor1.device)
+#     recalls = torch.tensor(recalls, device=tensor1.device)
+#     f1_thresholds = 2 * precisions * recalls / (precisions + recalls)
+#     f1_thresholds = torch.where(
+#         torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds
+#     )
+#     f1_value = torch.trapz(f1_thresholds) / len(thresholds)
+#     return f1_value
+
+
+DICT_METRICS = {
+    "d1": partial(delta, exponent=1.0),
+    "d2": partial(delta, exponent=2.0),
+    "d3": partial(delta, exponent=3.0),
+    "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
+    "rmselog": lambda gt, pred: torch.sqrt(
+        ((torch.log(gt) - torch.log(pred)) ** 2).mean()
+    ),
+    "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
+    "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
+    "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
+    "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(),
+    "medianlog": lambda gt, pred: 100
+    * (torch.log(pred) - torch.log(gt)).median().abs(),
+    "d_auc": d_auc,
+    "d1_ssi": d1_ssi,
+}
+
+
+# DICT_METRICS_3D = {
+#     "chamfer": lambda gt, pred, thresholds: chamfer_dist(
+#         gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
+#     ),
+#     "F1": lambda gt, pred, thresholds: f1_score(
+#         gt.unsqueeze(0).permute(0, 2, 1),
+#         pred.unsqueeze(0).permute(0, 2, 1),
+#         thresholds=thresholds,
+#     ),
+# }
+
+
+DICT_METRICS_D = {
+    "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to(
+        torch.float32
+    ),
+    "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt),
+}
+
+
+def eval_depth(
+    gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
+):
+    summary_metrics = defaultdict(list)
+    preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
+    for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
+        if max_depth is not None:
+            mask = torch.logical_and(mask, gt <= max_depth)
+        for name, fn in DICT_METRICS.items():
+            summary_metrics[name].append(fn(gt[mask], pred[mask]).mean())
+    return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
+
+
+# def eval_3d(
+#     gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
+# ):
+#     summary_metrics = defaultdict(list)
+#     w_max = min(gts.shape[-1] // 4, 400)
+#     gts = F.interpolate(
+#         gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest"
+#     )
+#     preds = F.interpolate(preds, gts.shape[-2:], mode="nearest")
+#     masks = F.interpolate(
+#         masks.to(torch.float32), gts.shape[-2:], mode="nearest"
+#     ).bool()
+#     for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
+#         if not torch.any(mask):
+#             continue
+#         for name, fn in DICT_METRICS_3D.items():
+#             summary_metrics[name].append(
+#                 fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean()
+#             )
+#     return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
diff --git a/flash3d/unidepth/utils/geometric.py b/flash3d/unidepth/utils/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..b942beb1af3c58e8910a1e113bfd974a1fd169ce
--- /dev/null
+++ b/flash3d/unidepth/utils/geometric.py
@@ -0,0 +1,248 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+
+def generate_rays(
+    camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
+):
+    batch_size, device, dtype = (
+        camera_intrinsics.shape[0],
+        camera_intrinsics.device,
+        camera_intrinsics.dtype,
+    )
+    height, width = image_shape
+    # Generate grid of pixel coordinates
+    pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
+    pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
+    if noisy:
+        pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
+        pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
+    pixel_coords = torch.stack(
+        [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
+    )  # (H, W, 2)
+    pixel_coords = pixel_coords + 0.5
+
+    # Calculate ray directions
+    intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype)  # (B, 3, 3)
+    homogeneous_coords = torch.cat(
+        [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
+    )  # (H, W, 3)
+    ray_directions = torch.matmul(
+        intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
+    )  # (3, H*W)
+    ray_directions = F.normalize(ray_directions, dim=1)  # (B, 3, H*W)
+    ray_directions = ray_directions.permute(0, 2, 1)  # (B, H*W, 3)
+
+    theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
+    phi = torch.acos(ray_directions[..., 1])
+    # pitch = torch.asin(ray_directions[..., 1])
+    # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
+    angles = torch.stack([theta, phi], dim=-1)
+    return ray_directions, angles
+
+
+@torch.jit.script
+def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
+    theta = spherical_tensor[..., 0]  # Extract polar angle
+    phi = spherical_tensor[..., 1]  # Extract azimuthal angle
+    z = spherical_tensor[..., 2]  # Extract zbuffer depth
+
+    # y = r * cos(phi)
+    # x = r * sin(phi) * sin(theta)
+    # z = r * sin(phi) * cos(theta)
+    # =>
+    # r = z / sin(phi) / cos(theta)
+    # y = z / (sin(phi) / cos(phi)) / cos(theta)
+    # x = z * sin(theta) / cos(theta)
+    x = z * torch.tan(theta)
+    y = z / torch.tan(phi) / torch.cos(theta)
+
+    euclidean_tensor = torch.stack((x, y, z), dim=-1)
+    return euclidean_tensor
+
+
+@torch.jit.script
+def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
+    theta = spherical_tensor[..., 0]  # Extract polar angle
+    phi = spherical_tensor[..., 1]  # Extract azimuthal angle
+    r = spherical_tensor[..., 2]  # Extract radius
+    # y = r * cos(phi)
+    # x = r * sin(phi) * sin(theta)
+    # z = r * sin(phi) * cos(theta)
+    x = r * torch.sin(phi) * torch.sin(theta)
+    y = r * torch.cos(phi)
+    z = r * torch.cos(theta) * torch.sin(phi)
+
+    euclidean_tensor = torch.stack((x, y, z), dim=-1)
+    return euclidean_tensor
+
+
+@torch.jit.script
+def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
+    x = spherical_tensor[..., 0]  # Extract polar angle
+    y = spherical_tensor[..., 1]  # Extract azimuthal angle
+    z = spherical_tensor[..., 2]  # Extract radius
+    # y = r * cos(phi)
+    # x = r * sin(phi) * sin(theta)
+    # z = r * sin(phi) * cos(theta)
+    r = torch.sqrt(x**2 + y**2 + z**2)
+    theta = torch.atan2(x / r, z / r)
+    phi = torch.acos(y / r)
+
+    euclidean_tensor = torch.stack((theta, phi, r), dim=-1)
+    return euclidean_tensor
+
+
+@torch.jit.script
+def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
+    pitch = torch.asin(euclidean_tensor[..., 1])
+    yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
+    z = euclidean_tensor[..., 2]  # Extract zbuffer depth
+    euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1)
+    return euclidean_tensor
+
+
+@torch.jit.script
+def unproject_points(
+    depth: torch.Tensor, camera_intrinsics: torch.Tensor
+) -> torch.Tensor:
+    """
+    Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.
+
+    Args:
+        depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
+        camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
+
+    Returns:
+        torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
+    """
+    batch_size, _, height, width = depth.shape
+    device = depth.device
+
+    # Create pixel grid
+    y_coords, x_coords = torch.meshgrid(
+        torch.arange(height, device=device),
+        torch.arange(width, device=device),
+        indexing="ij",
+    )
+    pixel_coords = torch.stack((x_coords, y_coords), dim=-1)  # (H, W, 2)
+
+    # Get homogeneous coords (u v 1)
+    pixel_coords_homogeneous = torch.cat(
+        (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
+    )
+    pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(
+        1
+    )  # (3, H*W)
+    # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
+    unprojected_points = torch.matmul(
+        torch.inverse(camera_intrinsics), pixel_coords_homogeneous
+    )  # (B, 3, H*W)
+    unprojected_points = unprojected_points.view(
+        batch_size, 3, height, width
+    )  # (B, 3, H, W)
+    unprojected_points = unprojected_points * depth  # (B, 3, H, W)
+    return unprojected_points
+
+
+@torch.jit.script
+def project_points(
+    points_3d: torch.Tensor,
+    intrinsic_matrix: torch.Tensor,
+    image_shape: Tuple[int, int],
+) -> torch.Tensor:
+    # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T
+    points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
+
+    # Normalize projected points: (u v w) -> (u / w, v / w, 1)
+    points_2d = points_2d[..., :2] / points_2d[..., 2:]
+
+    # To pixels (rounding!!!), no int as it breaks gradient
+    points_2d = points_2d.round()
+
+    # pointa need to be inside the image (can it diverge onto all points out???)
+    valid_mask = (
+        (points_2d[..., 0] >= 0)
+        & (points_2d[..., 0] < image_shape[1])
+        & (points_2d[..., 1] >= 0)
+        & (points_2d[..., 1] < image_shape[0])
+    )
+
+    # Calculate the flat indices of the valid pixels
+    flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
+    flat_indices = flat_points_2d.long()
+
+    # Create depth maps and counts using scatter_add, (B, H, W)
+    depth_maps = torch.zeros(
+        [points_3d.shape[0], *image_shape], device=points_3d.device
+    )
+    counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)
+
+    # Loop over batches to apply masks and accumulate depth/count values
+    for i in range(points_3d.shape[0]):
+        valid_indices = flat_indices[i, valid_mask[i]]
+        depth_maps[i].view(-1).scatter_add_(
+            0, valid_indices, points_3d[i, valid_mask[i], 2]
+        )
+        counts[i].view(-1).scatter_add_(
+            0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
+        )
+
+    # Calculate mean depth for each pixel in each batch
+    mean_depth_maps = depth_maps / counts.clamp(min=1.0)
+    return mean_depth_maps.reshape(-1, 1, *image_shape)  # (B, 1, H, W)
+
+
+@torch.jit.script
+def downsample(data: torch.Tensor, downsample_factor: int = 2):
+    N, _, H, W = data.shape
+    data = data.view(
+        N,
+        H // downsample_factor,
+        downsample_factor,
+        W // downsample_factor,
+        downsample_factor,
+        1,
+    )
+    data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
+    data = data.view(-1, downsample_factor * downsample_factor)
+    data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
+    data = torch.min(data_tmp, dim=-1).values
+    data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
+    data = torch.where(data > 1000, torch.zeros_like(data), data)
+    return data
+
+
+@torch.jit.script
+def flat_interpolate(
+    flat_tensor: torch.Tensor,
+    old: Tuple[int, int],
+    new: Tuple[int, int],
+    antialias: bool = True,
+    mode: str = "bilinear",
+) -> torch.Tensor:
+    if old[0] == new[0] and old[1] == new[1]:
+        return flat_tensor
+    tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
+        0, 3, 1, 2
+    )  # b c h w
+    tensor_interp = F.interpolate(
+        tensor,
+        size=(new[0], new[1]),
+        mode=mode,
+        align_corners=False,
+        antialias=antialias,
+    )
+    flat_tensor_interp = tensor_interp.view(
+        flat_tensor.shape[0], -1, new[0] * new[1]
+    ).permute(
+        0, 2, 1
+    )  # b (h w) c
+    return flat_tensor_interp.contiguous()
diff --git a/flash3d/unidepth/utils/misc.py b/flash3d/unidepth/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..23ae7adc0048b313dca985175a5713644be5e75d
--- /dev/null
+++ b/flash3d/unidepth/utils/misc.py
@@ -0,0 +1,403 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from functools import wraps
+
+import numpy as np
+from scipy import interpolate
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat, reduce
+
+
+def max_stack(tensors):
+    return torch.stack(tensors, dim=-1).max(dim=-1)[0]
+
+
+def softmax_stack(tensors, temperature=1.0):
+    return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1)
+
+
+def mean_stack(tensors):
+    if len(tensors) == 1:
+        return tensors[0]
+    return torch.stack(tensors, dim=-1).mean(dim=-1)
+
+
+def sum_stack(tensors):
+    return torch.stack(tensors, dim=-1).sum(dim=-1)
+
+
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.float()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.float()
+
+
+def format_seconds(seconds):
+    minutes, seconds = divmod(seconds, 60)
+    hours, minutes = divmod(minutes, 60)
+    return f"{hours:d}:{minutes:02d}:{seconds:02d}"
+
+
+def get_params(module, lr, wd):
+    skip_list = {}
+    skip_keywords = {}
+    if hasattr(module, "no_weight_decay"):
+        skip_list = module.no_weight_decay()
+    if hasattr(module, "no_weight_decay_keywords"):
+        skip_keywords = module.no_weight_decay_keywords()
+    has_decay = []
+    no_decay = []
+    for name, param in module.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if (
+            (name in skip_list)
+            or any((kw in name for kw in skip_keywords))
+            or len(param.shape) == 1
+        ):
+            # if (name in skip_list) or any((kw in name for kw in skip_keywords)):
+            # print(name, skip_keywords)
+            no_decay.append(param)
+        else:
+            has_decay.append(param)
+
+    group1 = {
+        "params": has_decay,
+        "weight_decay": wd,
+        "lr": lr,
+        "weight_decay_init": wd,
+        "weight_decay_base": wd,
+        "lr_init": lr,
+        "lr_base": lr,
+    }
+    group2 = {
+        "params": no_decay,
+        "weight_decay": 0.0,
+        "lr": lr,
+        "weight_decay_init": 0.0,
+        "weight_decay_base": 0.0,
+        "weight_decay_final": 0.0,
+        "lr_init": lr,
+        "lr_base": lr,
+    }
+    return [group1, group2], [lr, lr]
+
+
+def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
+    if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
+        return 0
+    elif var_name.startswith("patch_embed"):
+        return 0
+    elif var_name.startswith("layers"):
+        if var_name.split(".")[2] == "blocks":
+            stage_id = int(var_name.split(".")[1])
+            layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
+            return layer_id + 1
+        elif var_name.split(".")[2] == "downsample":
+            stage_id = int(var_name.split(".")[1])
+            layer_id = sum(layers_per_stage[: stage_id + 1])
+            return layer_id
+    else:
+        return num_max_layer - 1
+
+
+def get_params_layerdecayswin(module, lr, wd, ld):
+    skip_list = {}
+    skip_keywords = {}
+    if hasattr(module, "no_weight_decay"):
+        skip_list = module.no_weight_decay()
+    if hasattr(module, "no_weight_decay_keywords"):
+        skip_keywords = module.no_weight_decay_keywords()
+    layers_per_stage = module.depths
+    num_layers = sum(layers_per_stage) + 1
+    lrs = []
+    params = []
+    for name, param in module.named_parameters():
+        if not param.requires_grad:
+            print(f"{name} frozen")
+            continue  # frozen weights
+        layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
+        lr_cur = lr * ld ** (num_layers - layer_id - 1)
+        # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"):
+        if (name in skip_list) or any((kw in name for kw in skip_keywords)):
+            wd_cur = 0.0
+        else:
+            wd_cur = wd
+        params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
+        lrs.append(lr_cur)
+    return params, lrs
+
+
+def log(t, eps: float = 1e-5):
+    return torch.log(t.clamp(min=eps))
+
+
+def l2norm(t):
+    return F.normalize(t, dim=-1)
+
+
+def exists(val):
+    return val is not None
+
+
+def identity(t, *args, **kwargs):
+    return t
+
+
+def divisible_by(numer, denom):
+    return (numer % denom) == 0
+
+
+def first(arr, d=None):
+    if len(arr) == 0:
+        return d
+    return arr[0]
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+
+    return inner
+
+
+def once(fn):
+    called = False
+
+    @wraps(fn)
+    def inner(x):
+        nonlocal called
+        if called:
+            return
+        called = True
+        return fn(x)
+
+    return inner
+
+
+def _many(fn):
+    @wraps(fn)
+    def inner(tensors, pattern, **kwargs):
+        return (fn(tensor, pattern, **kwargs) for tensor in tensors)
+
+    return inner
+
+
+rearrange_many = _many(rearrange)
+repeat_many = _many(repeat)
+reduce_many = _many(reduce)
+
+
+def load_pretrained(state_dict, checkpoint):
+    checkpoint_model = checkpoint["model"]
+    if any([True if "encoder." in k else False for k in checkpoint_model.keys()]):
+        checkpoint_model = {
+            k.replace("encoder.", ""): v
+            for k, v in checkpoint_model.items()
+            if k.startswith("encoder.")
+        }
+        print("Detect pre-trained model, remove [encoder.] prefix.")
+    else:
+        print("Detect non-pre-trained model, pass without doing anything.")
+    print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
+    checkpoint = load_checkpoint_swin(state_dict, checkpoint_model)
+
+
+def load_checkpoint_swin(model, checkpoint_model):
+    state_dict = model.state_dict()
+    # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
+    all_keys = list(checkpoint_model.keys())
+    for key in all_keys:
+        if "relative_position_bias_table" in key:
+            relative_position_bias_table_pretrained = checkpoint_model[key]
+            relative_position_bias_table_current = state_dict[key]
+            L1, nH1 = relative_position_bias_table_pretrained.size()
+            L2, nH2 = relative_position_bias_table_current.size()
+            if nH1 != nH2:
+                print(f"Error in loading {key}, passing......")
+            else:
+                if L1 != L2:
+                    print(f"{key}: Interpolate relative_position_bias_table using geo.")
+                    src_size = int(L1**0.5)
+                    dst_size = int(L2**0.5)
+
+                    def geometric_progression(a, r, n):
+                        return a * (1.0 - r**n) / (1.0 - r)
+
+                    left, right = 1.01, 1.5
+                    while right - left > 1e-6:
+                        q = (left + right) / 2.0
+                        gp = geometric_progression(1, q, src_size // 2)
+                        if gp > dst_size // 2:
+                            right = q
+                        else:
+                            left = q
+
+                    # if q > 1.090307:
+                    #     q = 1.090307
+
+                    dis = []
+                    cur = 1
+                    for i in range(src_size // 2):
+                        dis.append(cur)
+                        cur += q ** (i + 1)
+
+                    r_ids = [-_ for _ in reversed(dis)]
+
+                    x = r_ids + [0] + dis
+                    y = r_ids + [0] + dis
+
+                    t = dst_size // 2.0
+                    dx = np.arange(-t, t + 0.1, 1.0)
+                    dy = np.arange(-t, t + 0.1, 1.0)
+
+                    print("Original positions = %s" % str(x))
+                    print("Target positions = %s" % str(dx))
+
+                    all_rel_pos_bias = []
+
+                    for i in range(nH1):
+                        z = (
+                            relative_position_bias_table_pretrained[:, i]
+                            .view(src_size, src_size)
+                            .float()
+                            .numpy()
+                        )
+                        f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
+                        all_rel_pos_bias.append(
+                            torch.Tensor(f_cubic(dx, dy))
+                            .contiguous()
+                            .view(-1, 1)
+                            .to(relative_position_bias_table_pretrained.device)
+                        )
+
+                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+                    checkpoint_model[key] = new_rel_pos_bias
+
+    # delete relative_position_index since we always re-init it
+    relative_position_index_keys = [
+        k for k in checkpoint_model.keys() if "relative_position_index" in k
+    ]
+    for k in relative_position_index_keys:
+        del checkpoint_model[k]
+
+    # delete relative_coords_table since we always re-init it
+    relative_coords_table_keys = [
+        k for k in checkpoint_model.keys() if "relative_coords_table" in k
+    ]
+    for k in relative_coords_table_keys:
+        del checkpoint_model[k]
+
+    # # re-map keys due to name change
+    rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
+    for k in rpe_mlp_keys:
+        checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
+
+    # delete attn_mask since we always re-init it
+    attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
+    for k in attn_mask_keys:
+        del checkpoint_model[k]
+
+    encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
+    for k in encoder_keys:
+        checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
+
+    return checkpoint_model
+
+
+def add_padding_metas(out, image_metas):
+    device = out.device
+    # left, right, top, bottom
+    paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas]
+    paddings = torch.stack(paddings).to(device)
+    outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
+    return torch.stack(outs)
+
+
+def remove_padding(out, paddings):
+    B, C, H, W = out.shape
+    device = out.device
+    # left, right, top, bottom
+    paddings = torch.stack(paddings).to(device)
+    outs = [
+        o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]]
+        for padding, o in zip(paddings, out)
+    ]
+    return torch.stack(outs)
+
+
+def remove_padding_metas(out, image_metas):
+    B, C, H, W = out.shape
+    device = out.device
+    # left, right, top, bottom
+    paddings = [
+        torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas
+    ]
+    return remove_padding(out, paddings)
+
+
+def ssi_helper(tensor1, tensor2):
+    stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
+    tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
+    scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
+        tensor2_one.T @ tensor1.unsqueeze(1)
+    )
+    scale, shift = scale_shift.squeeze().chunk(2, dim=0)
+    return scale, shift
+
+
+def calculate_mean_values(names, values):
+    # Create a defaultdict to store sum and count for each name
+    name_values = {name: {} for name in names}
+
+    # Iterate through the lists and accumulate values for each name
+    for name, value in zip(names, values):
+        name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
+        name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
+
+    # Calculate mean values and create the output dictionary
+    output_dict = {
+        name: name_values[name]["sum"] / name_values[name]["count"]
+        for name in name_values
+    }
+
+    return output_dict
+
+
+def remove_leading_dim(infos):
+    if isinstance(infos, dict):
+        return {k: remove_leading_dim(v) for k, v in infos.items()}
+    elif isinstance(infos, torch.Tensor):
+        return infos.squeeze(0)
+    else:
+        return infos
diff --git a/flash3d/unidepth/utils/positional_embedding.py b/flash3d/unidepth/utils/positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..8883076cb8d10896322d2973d6d0cd5df35e6943
--- /dev/null
+++ b/flash3d/unidepth/utils/positional_embedding.py
@@ -0,0 +1,274 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from math import pi
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from einops import rearrange, repeat
+
+
+class PositionEmbeddingSine(nn.Module):
+    def __init__(
+        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+    ):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * pi
+        self.scale = scale
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros(
+                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
+            )
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (
+            2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
+        )
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+    def __repr__(self, _repr_indent=4):
+        head = "Positional encoding " + self.__class__.__name__
+        body = [
+            "num_pos_feats: {}".format(self.num_pos_feats),
+            "temperature: {}".format(self.temperature),
+            "normalize: {}".format(self.normalize),
+            "scale: {}".format(self.scale),
+        ]
+        # _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
+
+
+class LearnedSinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        assert (dim % 2) == 0
+        half_dim = dim // 2
+        self.weights = nn.Parameter(torch.randn(half_dim))
+
+    def forward(self, x):
+        x = rearrange(x, "b -> b 1")
+        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+        fouriered = torch.cat((x, fouriered), dim=-1)
+        return fouriered
+
+
+def broadcat(tensors, dim=-1):
+    num_tensors = len(tensors)
+    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+    shape_len = list(shape_lens)[0]
+    dim = (dim + shape_len) if dim < 0 else dim
+    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+    assert all(
+        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatentation"
+    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+    expanded_dims.insert(dim, (dim, dims[dim]))
+    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+    return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        pt_seq_len,
+        ft_seq_len=None,
+        custom_freqs=None,
+        freqs_for="lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+    ):
+        super().__init__()
+        if custom_freqs:
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+        else:
+            raise ValueError(f"unknown modality {freqs_for}")
+
+        if ft_seq_len is None:
+            ft_seq_len = pt_seq_len
+        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+        freqs_h = torch.einsum("..., f -> ... f", t, freqs)
+        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+
+        freqs_w = torch.einsum("..., f -> ... f", t, freqs)
+        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
+
+        self.register_buffer("freqs_cos", freqs.cos())
+        self.register_buffer("freqs_sin", freqs.sin())
+
+        print("======== shape of rope freq", self.freqs_cos.shape, "========")
+
+    def forward(self, t, start_index=0):
+        rot_dim = self.freqs_cos.shape[-1]
+        end_index = start_index + rot_dim
+        assert (
+            rot_dim <= t.shape[-1]
+        ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+        t_left, t, t_right = (
+            t[..., :start_index],
+            t[..., start_index:end_index],
+            t[..., end_index:],
+        )
+        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
+        return torch.cat((t_left, t, t_right), dim=-1)
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+    def __init__(
+        self,
+        dim,
+        pt_seq_len,
+        ft_seq_len=None,
+        custom_freqs=None,
+        freqs_for="lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+    ):
+        super().__init__()
+        if custom_freqs:
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+        else:
+            raise ValueError(f"unknown modality {freqs_for}")
+
+        if ft_seq_len is None:
+            ft_seq_len = pt_seq_len
+        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+        freqs = torch.einsum("..., f -> ... f", t, freqs)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+        self.register_buffer("freqs_cos", freqs_cos)
+        self.register_buffer("freqs_sin", freqs_sin)
+
+    def forward(self, t):
+        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
+
+
+from math import log2
+
+
+def generate_fourier_features(
+    x: torch.Tensor,
+    dim: int = 512,
+    max_freq: int = 64,
+    use_cos: bool = False,
+    use_log: bool = False,
+    cat_orig: bool = False,
+):
+    x_orig = x
+    device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
+    num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
+
+    if use_log:
+        scales = 2.0 ** torch.linspace(
+            0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
+        )
+    else:
+        scales = torch.linspace(
+            1.0, max_freq / 2, num_bands, device=device, dtype=dtype
+        )
+
+    x = x.unsqueeze(-1)
+    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+
+    x = x * scales * pi
+    x = torch.cat(
+        (
+            [x.sin(), x.cos()]
+            if use_cos
+            else [
+                x.sin(),
+            ]
+        ),
+        dim=-1,
+    )
+    x = x.flatten(-2)
+    if cat_orig:
+        return torch.cat((x, x_orig), dim=-1)
+    return x
+
+
+# from PIL import Image
+# from unidepth.utils import image_grid, colorize
+# if __name__ == "__main__":
+#     H, W = 512, 512
+#     resolution = 128
+#     mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
+#     mesh = torch.stack(mesh, dim=0).unsqueeze(0)
+#     mesh = mesh.view(1, 2, -1).permute(0, 2, 1)
+
+#     features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True)
+#     channels = features.shape[-1]
+#     print(features.shape)
+
+#     features = features[0].view(H, W, channels).permute(2, 0, 1).numpy()
+#     Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png")
diff --git a/flash3d/unidepth/utils/sht.py b/flash3d/unidepth/utils/sht.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b89273a8f20b4da5ba296b175c856c974df0984
--- /dev/null
+++ b/flash3d/unidepth/utils/sht.py
@@ -0,0 +1,1637 @@
+"""Real spherical harmonics in Cartesian form for PyTorch.
+
+This is an autogenerated file. See
+https://github.com/cheind/torch-spherical-harmonics
+for more information.
+"""
+
+import torch
+
+
+def rsh_cart_0(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 0.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,1) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+        ],
+        -1,
+    )
+
+
+def rsh_cart_1(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 1.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,4) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+        ],
+        -1,
+    )
+
+
+def rsh_cart_2(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 2.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,9) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+        ],
+        -1,
+    )
+
+
+def rsh_cart_3(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 3.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,16) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+        ],
+        -1,
+    )
+
+
+def rsh_cart_4(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 4.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,25) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    x4 = x2**2
+    y4 = y2**2
+    z4 = z2**2
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+            2.5033429417967 * xy * (x2 - y2),
+            -1.77013076977993 * yz * (3.0 * x2 - y2),
+            0.126156626101008 * xy * (52.5 * z2 - 7.5),
+            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            1.48099765681286
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 0.952069922236839 * z2
+            + 0.317356640745613,
+            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+            -1.77013076977993 * xz * (x2 - 3.0 * y2),
+            -3.75501441269506 * x2 * y2
+            + 0.625835735449176 * x4
+            + 0.625835735449176 * y4,
+        ],
+        -1,
+    )
+
+
+def rsh_cart_5(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 5.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,36) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    x4 = x2**2
+    y4 = y2**2
+    z4 = z2**2
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+            2.5033429417967 * xy * (x2 - y2),
+            -1.77013076977993 * yz * (3.0 * x2 - y2),
+            0.126156626101008 * xy * (52.5 * z2 - 7.5),
+            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            1.48099765681286
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 0.952069922236839 * z2
+            + 0.317356640745613,
+            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+            -1.77013076977993 * xz * (x2 - 3.0 * y2),
+            -3.75501441269506 * x2 * y2
+            + 0.625835735449176 * x4
+            + 0.625835735449176 * y4,
+            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            8.30264925952416 * xy * z * (x2 - y2),
+            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.241571547304372
+            * y
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            -1.24747010616985 * z * (1.5 * z2 - 0.5)
+            + 1.6840846433293
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 0.498988042467941 * z,
+            0.241571547304372
+            * x
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+        ],
+        -1,
+    )
+
+
+def rsh_cart_6(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 6.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,49) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    x4 = x2**2
+    y4 = y2**2
+    z4 = z2**2
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+            2.5033429417967 * xy * (x2 - y2),
+            -1.77013076977993 * yz * (3.0 * x2 - y2),
+            0.126156626101008 * xy * (52.5 * z2 - 7.5),
+            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            1.48099765681286
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 0.952069922236839 * z2
+            + 0.317356640745613,
+            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+            -1.77013076977993 * xz * (x2 - 3.0 * y2),
+            -3.75501441269506 * x2 * y2
+            + 0.625835735449176 * x4
+            + 0.625835735449176 * y4,
+            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            8.30264925952416 * xy * z * (x2 - y2),
+            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.241571547304372
+            * y
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            -1.24747010616985 * z * (1.5 * z2 - 0.5)
+            + 1.6840846433293
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 0.498988042467941 * z,
+            0.241571547304372
+            * x
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            4.09910463115149 * x**4 * xy
+            - 13.6636821038383 * xy**3
+            + 4.09910463115149 * xy * y**4,
+            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+            0.00584892228263444
+            * y
+            * (3.0 * x2 - y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0701870673916132
+            * xy
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.221950995245231
+            * y
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            -1.48328138624466
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            + 1.86469659985043
+            * z
+            * (
+                -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                + 1.8
+                * z
+                * (
+                    1.75
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    - 1.125 * z2
+                    + 0.375
+                )
+                + 0.533333333333333 * z
+            )
+            + 0.953538034014426 * z2
+            - 0.317846011338142,
+            0.221950995245231
+            * x
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            0.0350935336958066
+            * (x2 - y2)
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.00584892228263444
+            * x
+            * (x2 - 3.0 * y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            0.683184105191914 * x2**3
+            + 10.2477615778787 * x2 * y4
+            - 10.2477615778787 * x4 * y2
+            - 0.683184105191914 * y2**3,
+        ],
+        -1,
+    )
+
+
+def rsh_cart_7(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 7.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,64) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    x4 = x2**2
+    y4 = y2**2
+    z4 = z2**2
+
+    return torch.stack(
+        [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+            2.5033429417967 * xy * (x2 - y2),
+            -1.77013076977993 * yz * (3.0 * x2 - y2),
+            0.126156626101008 * xy * (52.5 * z2 - 7.5),
+            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            1.48099765681286
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 0.952069922236839 * z2
+            + 0.317356640745613,
+            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+            -1.77013076977993 * xz * (x2 - 3.0 * y2),
+            -3.75501441269506 * x2 * y2
+            + 0.625835735449176 * x4
+            + 0.625835735449176 * y4,
+            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            8.30264925952416 * xy * z * (x2 - y2),
+            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.241571547304372
+            * y
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            -1.24747010616985 * z * (1.5 * z2 - 0.5)
+            + 1.6840846433293
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 0.498988042467941 * z,
+            0.241571547304372
+            * x
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            4.09910463115149 * x**4 * xy
+            - 13.6636821038383 * xy**3
+            + 4.09910463115149 * xy * y**4,
+            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+            0.00584892228263444
+            * y
+            * (3.0 * x2 - y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0701870673916132
+            * xy
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.221950995245231
+            * y
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            -1.48328138624466
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            + 1.86469659985043
+            * z
+            * (
+                -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                + 1.8
+                * z
+                * (
+                    1.75
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    - 1.125 * z2
+                    + 0.375
+                )
+                + 0.533333333333333 * z
+            )
+            + 0.953538034014426 * z2
+            - 0.317846011338142,
+            0.221950995245231
+            * x
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            0.0350935336958066
+            * (x2 - y2)
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.00584892228263444
+            * x
+            * (x2 - 3.0 * y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            0.683184105191914 * x2**3
+            + 10.2477615778787 * x2 * y4
+            - 10.2477615778787 * x4 * y2
+            - 0.683184105191914 * y2**3,
+            -0.707162732524596
+            * y
+            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+            9.98394571852353e-5
+            * y
+            * (5197.5 - 67567.5 * z2)
+            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00239614697244565
+            * xy
+            * (x2 - y2)
+            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+            0.00397356022507413
+            * y
+            * (3.0 * x2 - y2)
+            * (
+                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                + 1063.125 * z2
+                - 118.125
+            ),
+            0.0561946276120613
+            * xy
+            * (
+                -4.8 * z * (52.5 * z2 - 7.5)
+                + 2.6
+                * z
+                * (
+                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                    - 91.875 * z2
+                    + 13.125
+                )
+                + 48.0 * z
+            ),
+            0.206472245902897
+            * y
+            * (
+                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 2.16666666666667
+                * z
+                * (
+                    -2.8 * z * (1.5 - 7.5 * z2)
+                    + 2.2
+                    * z
+                    * (
+                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                        + 9.375 * z2
+                        - 1.875
+                    )
+                    - 4.8 * z
+                )
+                - 10.9375 * z2
+                + 2.1875
+            ),
+            1.24862677781952 * z * (1.5 * z2 - 0.5)
+            - 1.68564615005635
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 2.02901851395672
+            * z
+            * (
+                -1.45833333333333
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                + 1.83333333333333
+                * z
+                * (
+                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                    + 1.8
+                    * z
+                    * (
+                        1.75
+                        * z
+                        * (
+                            1.66666666666667 * z * (1.5 * z2 - 0.5)
+                            - 0.666666666666667 * z
+                        )
+                        - 1.125 * z2
+                        + 0.375
+                    )
+                    + 0.533333333333333 * z
+                )
+                + 0.9375 * z2
+                - 0.3125
+            )
+            - 0.499450711127808 * z,
+            0.206472245902897
+            * x
+            * (
+                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 2.16666666666667
+                * z
+                * (
+                    -2.8 * z * (1.5 - 7.5 * z2)
+                    + 2.2
+                    * z
+                    * (
+                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                        + 9.375 * z2
+                        - 1.875
+                    )
+                    - 4.8 * z
+                )
+                - 10.9375 * z2
+                + 2.1875
+            ),
+            0.0280973138060306
+            * (x2 - y2)
+            * (
+                -4.8 * z * (52.5 * z2 - 7.5)
+                + 2.6
+                * z
+                * (
+                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                    - 91.875 * z2
+                    + 13.125
+                )
+                + 48.0 * z
+            ),
+            0.00397356022507413
+            * x
+            * (x2 - 3.0 * y2)
+            * (
+                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                + 1063.125 * z2
+                - 118.125
+            ),
+            0.000599036743111412
+            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+            * (-6.0 * x2 * y2 + x4 + y4),
+            9.98394571852353e-5
+            * x
+            * (5197.5 - 67567.5 * z2)
+            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+            -0.707162732524596
+            * x
+            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+        ],
+        -1,
+    )
+
+
+# @torch.jit.script
+def rsh_cart_8(xyz: torch.Tensor):
+    """Computes all real spherical harmonics up to degree 8.
+
+    This is an autogenerated method. See
+    https://github.com/cheind/torch-spherical-harmonics
+    for more information.
+
+    Params:
+        xyz: (N,...,3) tensor of points on the unit sphere
+
+    Returns:
+        rsh: (N,...,81) real spherical harmonics
+            projections of input. Ynm is found at index
+            `n*(n+1) + m`, with `0 <= n <= degree` and
+            `-n <= m <= n`.
+    """
+    x = xyz[..., 0]
+    y = xyz[..., 1]
+    z = xyz[..., 2]
+
+    x2 = x**2
+    y2 = y**2
+    z2 = z**2
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    x4 = x2**2
+    y4 = y2**2
+    # z4 = z2**2
+    return torch.stack(
+        [
+            0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
+            -0.48860251190292 * y,
+            0.48860251190292 * z,
+            -0.48860251190292 * x,
+            1.09254843059208 * xy,
+            -1.09254843059208 * yz,
+            0.94617469575756 * z2 - 0.31539156525252,
+            -1.09254843059208 * xz,
+            0.54627421529604 * x2 - 0.54627421529604 * y2,
+            -0.590043589926644 * y * (3.0 * x2 - y2),
+            2.89061144264055 * xy * z,
+            0.304697199642977 * y * (1.5 - 7.5 * z2),
+            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+            0.304697199642977 * x * (1.5 - 7.5 * z2),
+            1.44530572132028 * z * (x2 - y2),
+            -0.590043589926644 * x * (x2 - 3.0 * y2),
+            2.5033429417967 * xy * (x2 - y2),
+            -1.77013076977993 * yz * (3.0 * x2 - y2),
+            0.126156626101008 * xy * (52.5 * z2 - 7.5),
+            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            1.48099765681286
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 0.952069922236839 * z2
+            + 0.317356640745613,
+            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+            -1.77013076977993 * xz * (x2 - 3.0 * y2),
+            -3.75501441269506 * x2 * y2
+            + 0.625835735449176 * x4
+            + 0.625835735449176 * y4,
+            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            8.30264925952416 * xy * z * (x2 - y2),
+            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.241571547304372
+            * y
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            -1.24747010616985 * z * (1.5 * z2 - 0.5)
+            + 1.6840846433293
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 0.498988042467941 * z,
+            0.241571547304372
+            * x
+            * (
+                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 9.375 * z2
+                - 1.875
+            ),
+            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            4.09910463115149 * x**4 * xy
+            - 13.6636821038383 * xy**3
+            + 4.09910463115149 * xy * y**4,
+            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+            0.00584892228263444
+            * y
+            * (3.0 * x2 - y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0701870673916132
+            * xy
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.221950995245231
+            * y
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            -1.48328138624466
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            + 1.86469659985043
+            * z
+            * (
+                -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                + 1.8
+                * z
+                * (
+                    1.75
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    - 1.125 * z2
+                    + 0.375
+                )
+                + 0.533333333333333 * z
+            )
+            + 0.953538034014426 * z2
+            - 0.317846011338142,
+            0.221950995245231
+            * x
+            * (
+                -2.8 * z * (1.5 - 7.5 * z2)
+                + 2.2
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                - 4.8 * z
+            ),
+            0.0350935336958066
+            * (x2 - y2)
+            * (
+                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                - 91.875 * z2
+                + 13.125
+            ),
+            0.00584892228263444
+            * x
+            * (x2 - 3.0 * y2)
+            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            0.683184105191914 * x2**3
+            + 10.2477615778787 * x2 * y4
+            - 10.2477615778787 * x4 * y2
+            - 0.683184105191914 * y2**3,
+            -0.707162732524596
+            * y
+            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+            9.98394571852353e-5
+            * y
+            * (5197.5 - 67567.5 * z2)
+            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00239614697244565
+            * xy
+            * (x2 - y2)
+            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+            0.00397356022507413
+            * y
+            * (3.0 * x2 - y2)
+            * (
+                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                + 1063.125 * z2
+                - 118.125
+            ),
+            0.0561946276120613
+            * xy
+            * (
+                -4.8 * z * (52.5 * z2 - 7.5)
+                + 2.6
+                * z
+                * (
+                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                    - 91.875 * z2
+                    + 13.125
+                )
+                + 48.0 * z
+            ),
+            0.206472245902897
+            * y
+            * (
+                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 2.16666666666667
+                * z
+                * (
+                    -2.8 * z * (1.5 - 7.5 * z2)
+                    + 2.2
+                    * z
+                    * (
+                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                        + 9.375 * z2
+                        - 1.875
+                    )
+                    - 4.8 * z
+                )
+                - 10.9375 * z2
+                + 2.1875
+            ),
+            1.24862677781952 * z * (1.5 * z2 - 0.5)
+            - 1.68564615005635
+            * z
+            * (
+                1.75
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                - 1.125 * z2
+                + 0.375
+            )
+            + 2.02901851395672
+            * z
+            * (
+                -1.45833333333333
+                * z
+                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                + 1.83333333333333
+                * z
+                * (
+                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                    + 1.8
+                    * z
+                    * (
+                        1.75
+                        * z
+                        * (
+                            1.66666666666667 * z * (1.5 * z2 - 0.5)
+                            - 0.666666666666667 * z
+                        )
+                        - 1.125 * z2
+                        + 0.375
+                    )
+                    + 0.533333333333333 * z
+                )
+                + 0.9375 * z2
+                - 0.3125
+            )
+            - 0.499450711127808 * z,
+            0.206472245902897
+            * x
+            * (
+                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                + 2.16666666666667
+                * z
+                * (
+                    -2.8 * z * (1.5 - 7.5 * z2)
+                    + 2.2
+                    * z
+                    * (
+                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                        + 9.375 * z2
+                        - 1.875
+                    )
+                    - 4.8 * z
+                )
+                - 10.9375 * z2
+                + 2.1875
+            ),
+            0.0280973138060306
+            * (x2 - y2)
+            * (
+                -4.8 * z * (52.5 * z2 - 7.5)
+                + 2.6
+                * z
+                * (
+                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                    - 91.875 * z2
+                    + 13.125
+                )
+                + 48.0 * z
+            ),
+            0.00397356022507413
+            * x
+            * (x2 - 3.0 * y2)
+            * (
+                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                + 1063.125 * z2
+                - 118.125
+            ),
+            0.000599036743111412
+            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+            * (-6.0 * x2 * y2 + x4 + y4),
+            9.98394571852353e-5
+            * x
+            * (5197.5 - 67567.5 * z2)
+            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+            -0.707162732524596
+            * x
+            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+            5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
+            -2.91570664069932
+            * yz
+            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+            7.87853281621404e-6
+            * (1013512.5 * z2 - 67567.5)
+            * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+            5.10587282657803e-5
+            * y
+            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+            0.00147275890257803
+            * xy
+            * (x2 - y2)
+            * (
+                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+                - 14293.125 * z2
+                + 1299.375
+            ),
+            0.0028519853513317
+            * y
+            * (3.0 * x2 - y2)
+            * (
+                -7.33333333333333 * z * (52.5 - 472.5 * z2)
+                + 3.0
+                * z
+                * (
+                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                    + 1063.125 * z2
+                    - 118.125
+                )
+                - 560.0 * z
+            ),
+            0.0463392770473559
+            * xy
+            * (
+                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                + 2.5
+                * z
+                * (
+                    -4.8 * z * (52.5 * z2 - 7.5)
+                    + 2.6
+                    * z
+                    * (
+                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                        - 91.875 * z2
+                        + 13.125
+                    )
+                    + 48.0 * z
+                )
+                + 137.8125 * z2
+                - 19.6875
+            ),
+            0.193851103820053
+            * y
+            * (
+                3.2 * z * (1.5 - 7.5 * z2)
+                - 2.51428571428571
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                + 2.14285714285714
+                * z
+                * (
+                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 2.16666666666667
+                    * z
+                    * (
+                        -2.8 * z * (1.5 - 7.5 * z2)
+                        + 2.2
+                        * z
+                        * (
+                            2.25
+                            * z
+                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                            + 9.375 * z2
+                            - 1.875
+                        )
+                        - 4.8 * z
+                    )
+                    - 10.9375 * z2
+                    + 2.1875
+                )
+                + 5.48571428571429 * z
+            ),
+            1.48417251362228
+            * z
+            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+            - 1.86581687426801
+            * z
+            * (
+                -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                + 1.8
+                * z
+                * (
+                    1.75
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    - 1.125 * z2
+                    + 0.375
+                )
+                + 0.533333333333333 * z
+            )
+            + 2.1808249179756
+            * z
+            * (
+                1.14285714285714 * z * (1.5 * z2 - 0.5)
+                - 1.54285714285714
+                * z
+                * (
+                    1.75
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    - 1.125 * z2
+                    + 0.375
+                )
+                + 1.85714285714286
+                * z
+                * (
+                    -1.45833333333333
+                    * z
+                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+                    + 1.83333333333333
+                    * z
+                    * (
+                        -1.33333333333333 * z * (1.5 * z2 - 0.5)
+                        + 1.8
+                        * z
+                        * (
+                            1.75
+                            * z
+                            * (
+                                1.66666666666667 * z * (1.5 * z2 - 0.5)
+                                - 0.666666666666667 * z
+                            )
+                            - 1.125 * z2
+                            + 0.375
+                        )
+                        + 0.533333333333333 * z
+                    )
+                    + 0.9375 * z2
+                    - 0.3125
+                )
+                - 0.457142857142857 * z
+            )
+            - 0.954110901614325 * z2
+            + 0.318036967204775,
+            0.193851103820053
+            * x
+            * (
+                3.2 * z * (1.5 - 7.5 * z2)
+                - 2.51428571428571
+                * z
+                * (
+                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 9.375 * z2
+                    - 1.875
+                )
+                + 2.14285714285714
+                * z
+                * (
+                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                    + 2.16666666666667
+                    * z
+                    * (
+                        -2.8 * z * (1.5 - 7.5 * z2)
+                        + 2.2
+                        * z
+                        * (
+                            2.25
+                            * z
+                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+                            + 9.375 * z2
+                            - 1.875
+                        )
+                        - 4.8 * z
+                    )
+                    - 10.9375 * z2
+                    + 2.1875
+                )
+                + 5.48571428571429 * z
+            ),
+            0.0231696385236779
+            * (x2 - y2)
+            * (
+                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                + 2.5
+                * z
+                * (
+                    -4.8 * z * (52.5 * z2 - 7.5)
+                    + 2.6
+                    * z
+                    * (
+                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+                        - 91.875 * z2
+                        + 13.125
+                    )
+                    + 48.0 * z
+                )
+                + 137.8125 * z2
+                - 19.6875
+            ),
+            0.0028519853513317
+            * x
+            * (x2 - 3.0 * y2)
+            * (
+                -7.33333333333333 * z * (52.5 - 472.5 * z2)
+                + 3.0
+                * z
+                * (
+                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+                    + 1063.125 * z2
+                    - 118.125
+                )
+                - 560.0 * z
+            ),
+            0.000368189725644507
+            * (-6.0 * x2 * y2 + x4 + y4)
+            * (
+                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+                - 14293.125 * z2
+                + 1299.375
+            ),
+            5.10587282657803e-5
+            * x
+            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+            7.87853281621404e-6
+            * (1013512.5 * z2 - 67567.5)
+            * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+            -2.91570664069932
+            * xz
+            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+            -20.4099464848952 * x2**3 * y2
+            - 20.4099464848952 * x2 * y2**3
+            + 0.72892666017483 * x4**2
+            + 51.0248662122381 * x4 * y4
+            + 0.72892666017483 * y4**2,
+        ],
+        -1,
+    )
+
+
+__all__ = [
+    "rsh_cart_0",
+    "rsh_cart_1",
+    "rsh_cart_2",
+    "rsh_cart_3",
+    "rsh_cart_4",
+    "rsh_cart_5",
+    "rsh_cart_6",
+    "rsh_cart_7",
+    "rsh_cart_8",
+]
+
+
+from typing import Optional
+import torch
+
+
+class SphHarm(torch.nn.Module):
+    def __init__(self, m, n, dtype=torch.float32) -> None:
+        super().__init__()
+        self.dtype = dtype
+        m = torch.tensor(list(range(-m + 1, m)))
+        n = torch.tensor(list(range(n)))
+        self.is_normalized = False
+        vals = torch.cartesian_prod(m, n).T
+        vals = vals[:, vals[0] <= vals[1]]
+        m, n = vals.unbind(0)
+
+        self.register_buffer("m", tensor=m)
+        self.register_buffer("n", tensor=n)
+        self.register_buffer("l_max", tensor=torch.max(self.n))
+
+        f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
+        self.register_buffer("f_a", tensor=f_a)
+        self.register_buffer("f_b", tensor=f_b)
+        self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
+        self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
+        self.register_buffer("initial_value", tensor=initial_value)
+
+    @property
+    def device(self):
+        return next(self.buffers()).device
+
+    def forward(self, points: torch.Tensor) -> torch.Tensor:
+        """Computes the spherical harmonics."""
+        # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
+        B, N, D = points.shape
+        dtype = points.dtype
+        theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
+        cos_colatitude = torch.cos(phi)
+        legendre = self._gen_associated_legendre(cos_colatitude)
+        vals = torch.stack([self.m.abs(), self.n], dim=0)
+        vals = torch.cat(
+            [
+                vals.repeat(1, theta.shape[0]),
+                torch.arange(theta.shape[0], device=theta.device)
+                .unsqueeze(0)
+                .repeat_interleave(vals.shape[1], dim=1),
+            ],
+            dim=0,
+        )
+        legendre_vals = legendre[vals[0], vals[1], vals[2]]
+        legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
+        angle = torch.outer(self.m.abs(), theta)
+        vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
+        harmonics = torch.complex(
+            legendre_vals * torch.real(vandermonde),
+            legendre_vals * torch.imag(vandermonde),
+        )
+
+        # Negative order.
+        m = self.m.unsqueeze(-1)
+        harmonics = torch.where(
+            m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
+        )
+        harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
+        return harmonics
+
+    def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Generates mask for recurrence relation on the remaining entries.
+
+        The remaining entries are with respect to the diagonal and offdiagonal
+        entries.
+
+        Args:
+        l_max: see `gen_normalized_legendre`.
+        Returns:
+        torch.Tensors representing the mask used by the recurrence relations.
+        """
+
+        # Computes all coefficients.
+        m_mat, l_mat = torch.meshgrid(
+            torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+            torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+            indexing="ij",
+        )
+        if self.is_normalized:
+            c0 = l_mat * l_mat
+            c1 = m_mat * m_mat
+            c2 = 2.0 * l_mat
+            c3 = (l_mat - 1.0) * (l_mat - 1.0)
+            d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
+            d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
+        else:
+            d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
+            d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
+
+        d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
+        d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
+
+        d_zeros = torch.zeros(
+            (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+        )
+        d_zeros[d0_mask_indices] = d0[d0_mask_indices]
+        d0_mask = d_zeros
+
+        d_zeros = torch.zeros(
+            (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+        )
+        d_zeros[d1_mask_indices] = d1[d1_mask_indices]
+        d1_mask = d_zeros
+
+        # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
+        i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
+        j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
+        k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
+        mask = (i + j - k == 0).to(self.dtype)
+        d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
+        d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
+        return (d0_mask_3d, d1_mask_3d)
+
+    def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        coeff_0 = self.d0_mask_3d[i]
+        coeff_1 = self.d1_mask_3d[i]
+        h = torch.einsum(
+            "ij,ijk->ijk",
+            coeff_0,
+            torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
+        ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
+        p_val = p_val + h
+        return p_val
+
+    def _init_legendre(self):
+        a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
+        b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
+        if self.is_normalized:
+            # The initial value p(0,0).
+            initial_value: torch.Tensor = torch.tensor(
+                0.5 / (torch.pi**0.5), device=self.device
+            )
+            f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
+            f_b = torch.sqrt(2.0 * b_idx + 3.0)
+        else:
+            # The initial value p(0,0).
+            initial_value = torch.tensor(1.0, device=self.device)
+            f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
+            f_b = 2.0 * b_idx + 1.0
+
+        d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
+        return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
+
+    def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Computes associated Legendre functions (ALFs) of the first kind.
+
+        The ALFs of the first kind are used in spherical harmonics. The spherical
+        harmonic of degree `l` and order `m` can be written as
+        `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
+        normalization factor and θ and φ are the colatitude and longitude,
+        repectively. `N_l^m` is chosen in the way that the spherical harmonics form
+        a set of orthonormal basis function of L^2(S^2). For the computational
+        efficiency of spherical harmonics transform, the normalization factor is
+        used in the computation of the ALFs. In addition, normalizing `P_l^m`
+        avoids overflow/underflow and achieves better numerical stability. Three
+        recurrence relations are used in the computation.
+
+        Args:
+        l_max: The maximum degree of the associated Legendre function. Both the
+            degrees and orders are `[0, 1, 2, ..., l_max]`.
+        x: A vector of type `float32`, `float64` containing the sampled points in
+            spherical coordinates, at which the ALFs are computed; `x` is essentially
+            `cos(θ)`. For the numerical integration used by the spherical harmonics
+            transforms, `x` contains the quadrature points in the interval of
+            `[-1, 1]`. There are several approaches to provide the quadrature points:
+            Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
+            method (`scipy.special.roots_chebyu`), and Driscoll & Healy
+            method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
+            transforms and convolutions on the 2-sphere." Advances in applied
+            mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
+            points are nearly equal-spaced along θ and provide exact discrete
+            orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
+            operation, `W` is a diagonal matrix containing the quadrature weights,
+            and `I` is the identity matrix. The Gauss-Chebyshev points are equally
+            spaced, which only provide approximate discrete orthogonality. The
+            Driscoll & Healy qudarture points are equally spaced and provide the
+            exact discrete orthogonality. The number of sampling points is required to
+            be twice as the number of frequency points (modes) in the Driscoll & Healy
+            approach, which enables FFT and achieves a fast spherical harmonics
+            transform.
+        is_normalized: True if the associated Legendre functions are normalized.
+            With normalization, `N_l^m` is applied such that the spherical harmonics
+            form a set of orthonormal basis functions of L^2(S^2).
+
+        Returns:
+        The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
+        of the ALFs at `x`; the dimensions in the sequence of order, degree, and
+        evalution points.
+        """
+        p = torch.zeros(
+            (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
+        )
+        p[0, 0] = self.initial_value
+
+        # Compute the diagonal entries p(l,l) with recurrence.
+        y = torch.cumprod(
+            torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
+        )
+        p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
+        # torch.diag_indices(l_max + 1)
+        diag_indices = torch.stack(
+            [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
+        )
+        p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
+
+        diag_indices = torch.stack(
+            [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
+        )
+
+        # Compute the off-diagonal entries with recurrence.
+        p_offdiag = torch.einsum(
+            "ij,ij->ij",
+            torch.einsum("i,j->ij", self.f_b, x),
+            p[(diag_indices[0], diag_indices[1])],
+        )  # p[torch.diag_indices(l_max)])
+        p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
+            p_offdiag
+        )
+
+        # Compute the remaining entries with recurrence.
+        if self.l_max > 1:
+            for i in range(2, self.l_max + 1):
+                p = self._recursive(i, p, x)
+        return p
diff --git a/flash3d/unidepth/utils/visualization.py b/flash3d/unidepth/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8504ec0430924847a1c2123be0dcea6e00c6945d
--- /dev/null
+++ b/flash3d/unidepth/utils/visualization.py
@@ -0,0 +1,201 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import os
+
+import numpy as np
+from PIL import Image
+import matplotlib.cm
+import wandb
+import torch
+
+from unidepth.utils.misc import ssi_helper
+
+
+def colorize(
+    value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r"
+):
+    # if already RGB, do nothing
+    if value.ndim > 2:
+        if value.shape[-1] > 1:
+            return value
+        value = value[..., 0]
+    invalid_mask = value < 0.0001
+    # normalize
+    vmin = value.min() if vmin is None else vmin
+    vmax = value.max() if vmax is None else vmax
+    value = (value - vmin) / (vmax - vmin)  # vmin..vmax
+
+    # set color
+    cmapper = matplotlib.cm.get_cmap(cmap)
+    value = cmapper(value, bytes=True)  # (nxmx4)
+    value[invalid_mask] = 0
+    img = value[..., :3]
+    return img
+
+
+def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
+    if not len(imgs):
+        return None
+    assert len(imgs) == rows * cols
+    h, w = imgs[0].shape[:2]
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(
+            Image.fromarray(img.astype(np.uint8)).resize(
+                (w, h), resample=Image.BILINEAR
+            ),
+            box=(i % cols * w, i // cols * h),
+        )
+
+    return np.array(grid)
+
+
+def get_pointcloud_from_rgbd(
+    image: np.array,
+    depth: np.array,
+    mask: np.ndarray,
+    intrinsic_matrix: np.array,
+    extrinsic_matrix: np.array = None,
+):
+    depth = np.array(depth).squeeze()
+    mask = np.array(mask).squeeze()
+    # Mask the depth array
+    masked_depth = np.ma.masked_where(mask == False, depth)
+    # masked_depth = np.ma.masked_greater(masked_depth, 8000)
+    # Create idx array
+    idxs = np.indices(masked_depth.shape)
+    u_idxs = idxs[1]
+    v_idxs = idxs[0]
+    # Get only non-masked depth and idxs
+    z = masked_depth[~masked_depth.mask]
+    compressed_u_idxs = u_idxs[~masked_depth.mask]
+    compressed_v_idxs = v_idxs[~masked_depth.mask]
+    image = np.stack(
+        [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
+    )
+
+    # Calculate local position of each point
+    # Apply vectorized math to depth using compressed arrays
+    cx = intrinsic_matrix[0, 2]
+    fx = intrinsic_matrix[0, 0]
+    x = (compressed_u_idxs - cx) * z / fx
+    cy = intrinsic_matrix[1, 2]
+    fy = intrinsic_matrix[1, 1]
+    # Flip y as we want +y pointing up not down
+    y = -((compressed_v_idxs - cy) * z / fy)
+
+    # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords
+    # if extrinsic_matrix is not None:
+    #     # Calculate camera pose from extrinsic matrix
+    #     camera_matrix = np.linalg.inv(extrinsic_matrix)
+    #     # Create homogenous array of vectors by adding 4th entry of 1
+    #     # At the same time flip z as for eye space the camera is looking down the -z axis
+    #     w = np.ones(z.shape)
+    #     x_y_z_eye_hom = np.vstack((x, y, -z, w))
+    #     # Transform the points from eye space to world space
+    #     x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
+    #     return x_y_z_world.T
+    # else:
+    x_y_z_local = np.stack((x, y, z), axis=-1)
+    return np.concatenate([x_y_z_local, image], axis=-1)
+
+
+def save_file_ply(xyz, rgb, pc_file):
+    if rgb.max() < 1.001:
+        rgb = rgb * 255.0
+    rgb = rgb.astype(np.uint8)
+    # print(rgb)
+    with open(pc_file, "w") as f:
+        # headers
+        f.writelines(
+            [
+                "ply\n" "format ascii 1.0\n",
+                "element vertex {}\n".format(xyz.shape[0]),
+                "property float x\n",
+                "property float y\n",
+                "property float z\n",
+                "property uchar red\n",
+                "property uchar green\n",
+                "property uchar blue\n",
+                "end_header\n",
+            ]
+        )
+
+        for i in range(xyz.shape[0]):
+            str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
+                xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
+            )
+            f.write(str_v)
+
+
+# really awful fct... FIXME
+def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}):
+    rgbs = [
+        (127.5 * (rgb + 1))
+        .clip(0, 255)
+        .to(torch.uint8)
+        .cpu()
+        .detach()
+        .permute(1, 2, 0)
+        .numpy()
+        for rgb in rgbs
+    ]
+
+    new_gts, new_preds = [], []
+    if len(gts) > 0:
+        for i, gt in enumerate(gts):
+            scale, shift = ssi_helper(
+                gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach()
+            )
+            gt = gts[i].cpu().detach().squeeze().numpy()
+            pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
+            vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0
+            vmax = gt.max() if (gt > 0).any() else 0.1
+            new_gts.append(colorize(gt, vmin=vmin, vmax=vmax))
+            new_preds.append(colorize(pred, vmin=vmin, vmax=vmax))
+        gts, preds = new_gts, new_preds
+    else:
+        preds = [
+            colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
+            for i, pred in enumerate(preds)
+        ]
+
+    num_additional, additionals = 0, []
+    for name, info in infos.items():
+        num_additional += 1
+        if info.shape[1] == 3:
+            additionals.extend(
+                [
+                    (127.5 * (x + 1))
+                    .clip(0, 255)
+                    .to(torch.uint8)
+                    .cpu()
+                    .detach()
+                    .permute(1, 2, 0)
+                    .numpy()
+                    for x in info[:4]
+                ]
+            )
+        else:
+            additionals.extend(
+                [
+                    colorize(x.cpu().detach().squeeze().numpy())
+                    for i, x in enumerate(info[:4])
+                ]
+            )
+
+    num_rows = 2 + int(len(gts) > 0) + num_additional
+    artifacts_grid = image_grid(
+        [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
+    )
+    try:
+        wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step)
+    except:
+        Image.fromarray(artifacts_grid).save(
+            os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png")
+        )
+        print("Logging training images failed")
diff --git a/flash3d/util/vis3d.py b/flash3d/util/vis3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb53d03ffa735e4352e1f6eda38f12164b1fd71
--- /dev/null
+++ b/flash3d/util/vis3d.py
@@ -0,0 +1,135 @@
+from pathlib import Path
+from jaxtyping import Float
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+from plyfile import PlyData, PlyElement
+import torch
+from torch import Tensor
+from einops import rearrange, einsum
+
+
+def construct_list_of_attributes(num_rest: int) -> list[str]:
+    attributes = ["x", "y", "z", "nx", "ny", "nz"]
+    for i in range(3):
+        attributes.append(f"f_dc_{i}")
+    for i in range(num_rest):
+        attributes.append(f"f_rest_{i}")
+    attributes.append("opacity")
+    for i in range(3):
+        attributes.append(f"scale_{i}")
+    for i in range(4):
+        attributes.append(f"rot_{i}")
+    return attributes
+
+
+def export_ply(
+    means: Float[Tensor, "gaussian 3"],
+    scales: Float[Tensor, "gaussian 3"],
+    rotations: Float[Tensor, "gaussian 4"],
+    harmonics: Float[Tensor, "gaussian 3 d_sh"],
+    opacities: Float[Tensor, "gaussian"],
+    path: Path,
+):
+    path = Path(path)
+    # Shift the scene so that the median Gaussian is at the origin.
+    means = means - means.median(dim=0).values
+
+    # Rescale the scene so that most Gaussians are within range [-1, 1].
+    scale_factor = means.abs().quantile(0.95, dim=0).max()
+    means = means / scale_factor
+    scales = scales / scale_factor
+    scales = scales * 4.0
+    scales = torch.clamp(scales, 0, 0.0075)
+
+    # Define a rotation that makes +Z be the world up vector.
+    # rotation = [
+    #     [0, 0, 1],
+    #     [-1, 0, 0],
+    #     [0, -1, 0],
+    # ]
+    rotation = [
+        [1, 0, 0],
+        [0, 1, 0],
+        [0, 0, 1],
+    ]
+    rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device)
+
+    # The Polycam viewer seems to start at a 45 degree angle. Since we want to be
+    # looking directly at the object, we compose a 45 degree rotation onto the above
+    # rotation.
+    # adjustment = torch.tensor(
+    #     R.from_rotvec([0, 0, -45], True).as_matrix(),
+    #     dtype=torch.float32,
+    #     device=means.device,
+    # )
+    # rotation = adjustment @ rotation
+
+    # We also want to see the scene in camera space (as the default view). We therefore
+    # compose the w2c rotation onto the above rotation.
+    # rotation = rotation @ extrinsics[:3, :3].inverse()
+
+    # Apply the rotation to the means (Gaussian positions).
+    means = einsum(rotation, means, "i j, ... j -> ... i")
+
+    # Apply the rotation to the Gaussian rotations.
+    rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
+    rotations = rotation.detach().cpu().numpy() @ rotations
+    rotations = R.from_matrix(rotations).as_quat()
+    x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
+    rotations = np.stack((w, x, y, z), axis=-1)
+
+    # Since our axes are swizzled for the spherical harmonics, we only export the DC
+    # band.
+    harmonics_view_invariant = harmonics
+
+    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
+    elements = np.empty(means.shape[0], dtype=dtype_full)
+    attributes = (
+        means.detach().cpu().numpy(),
+        torch.zeros_like(means).detach().cpu().numpy(),
+        harmonics_view_invariant.detach().cpu().contiguous().numpy(),
+        opacities.detach().cpu().numpy(),
+        scales.log().detach().cpu().numpy(),
+        rotations,
+    )
+    attributes = np.concatenate(attributes, axis=1)
+    elements[:] = list(map(tuple, attributes))
+    path.parent.mkdir(exist_ok=True, parents=True)
+    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
+
+
+def save_ply(outputs, path, num_gauss=3):
+    pad = 32
+
+    def crop_r(t):
+        h, w = 256, 384
+        H = h + pad * 2
+        W = w + pad * 2
+        t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
+        t = t[..., pad:H-pad, pad:W-pad]
+        t = rearrange(t, "b c h w -> b c (h w)")
+        return t
+
+    def crop(t):
+        h, w = 256, 384
+        H = h + pad * 2
+        W = w + pad * 2
+        t = t[..., pad:H-pad, pad:W-pad]
+        return t
+
+    # import pdb
+    # pdb.set_trace()
+    means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
+    scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
+    rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
+    opacities = rearrange(crop(outputs[('gauss_opacity', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
+    harmonics = rearrange(crop(outputs[('gauss_features_dc', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
+
+    export_ply(
+        means,
+        scales,
+        rotations,
+        harmonics,
+        opacities,
+        path
+    )
\ No newline at end of file
diff --git a/pre-requirements.txt b/pre-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..63860662be0d341df031b72e453683cbf5cbd1df
--- /dev/null
+++ b/pre-requirements.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.2.2
+torchvision
+torchaudio
+xformers==0.0.25.post1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2b23c51e40c57a4c0b1d91d1a799d15d3ae77e01
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,16 @@
+einops
+huggingface-hub>=0.22.0
+imageio
+matplotlib
+safetensors
+scipy
+timm
+tqdm
+wandb
+neptune
+scikit-image
+plyfile
+omegaconf
+jaxtyping
+gradio
+spaces
\ No newline at end of file