| """ | |
| Author: Luigi Piccinelli | |
| Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
| """ | |
| import importlib | |
| from copy import deepcopy | |
| from math import ceil | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| from einops import rearrange | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from unidepth.models.unidepthv1.decoder import Decoder | |
| from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, | |
| IMAGENET_DATASET_STD) | |
| from unidepth.utils.distributed import is_main_process | |
| from unidepth.utils.geometric import (generate_rays, | |
| spherical_zbuffer_to_euclidean) | |
| from unidepth.utils.misc import get_params | |
| MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"} | |

# inference helpers
def _paddings(image_shape, network_shape):
    cur_h, cur_w = image_shape
    h, w = network_shape
    pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
    pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
    return pad_left, pad_right, pad_top, pad_bottom

def _shapes(image_shape, network_shape):
    h, w = image_shape
    input_ratio = w / h
    output_ratio = network_shape[1] / network_shape[0]
    if output_ratio > input_ratio:
        ratio = network_shape[0] / h  # fit height, pad width
    else:
        ratio = network_shape[1] / w  # fit width, pad height
    return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
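
# A worked example (numbers are illustrative): with an input of (H, W) =
# (480, 640) and a network shape of (462, 616), both aspect ratios are 4/3,
# so ratio = 616 / 640 = 0.9625 and the resized shape is (462, 616) with no
# padding. A square (480, 480) input instead yields ratio = 462 / 480, a
# (462, 462) resize, and 77 pixels of left/right padding from _paddings.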

def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    rgbs = F.interpolate(
        rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
    )
    rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
    if intrinsics is not None:
        intrinsics = intrinsics.clone()
        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
        return rgbs, intrinsics
    return rgbs, None
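
# The intrinsics update mirrors the image transform: focal lengths scale with
# the resize ratio and principal points additionally shift by the padding,
#   fx' = fx * r,  fy' = fy * r,  cx' = cx * r + pad_left,  cy' = cy * r + pad_top,
# and _postprocess below inverts exactly this mapping.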

def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    # pred mean, trim paddings, and upsample to input dim
    predictions = sum(
        [
            F.interpolate(
                x.clone(),
                size=shapes,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            for x in predictions
        ]
    ) / len(predictions)
    predictions = predictions[
        ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
    ]
    predictions = F.interpolate(
        predictions,
        size=original_shapes,
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
    intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
    intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
    return predictions, intrinsics

class UniDepthV1(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniDepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    tags=["monocular-metric-depth-estimation"],
):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.build(config)
        self.eps = eps

    def forward(self, inputs, image_metas=None):
        rgbs = inputs["image"]
        gt_intrinsics = inputs.get("K")
        H, W = rgbs.shape[-2:]

        # Encode
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens

        # Get camera infos, if any
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True  # use GT camera in fwd

        # Decode
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})
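        # Average the per-scale depth predictions at full network resolution.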
        predictions = sum(
            [
                F.interpolate(
                    x.clone(),
                    size=self.image_shape,
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                for x in predictions
            ]
        ) / len(predictions)

        # Final 3D points backprojection
        pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
        # You may want to use inputs["angles"] if available?
        pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
        pred_angles = F.interpolate(
            pred_angles.clone(),
            size=self.image_shape,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
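        # Stack per-pixel ray angles with z-buffer depth, then unproject to
        # Euclidean (x, y, z) points.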
        points_3d = torch.cat((pred_angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # Output data, use for loss computation
        outputs = {
            "angles": pred_angles,
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions.squeeze(1),
        }
        self.pixel_decoder.test_fixed_camera = False
        return outputs

    @torch.no_grad()
    def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
        if rgbs.ndim == 3:
            rgbs = rgbs.unsqueeze(0)
        if intrinsics is not None and intrinsics.ndim == 2:
            intrinsics = intrinsics.unsqueeze(0)
        B, _, H, W = rgbs.shape

        rgbs = rgbs.to(self.device)
        if intrinsics is not None:
            intrinsics = intrinsics.to(self.device)

        # process image and intrinsics (if any) to match network input (slow?)
        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
            rgbs = rgbs.to(torch.float32).div(255)
        if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
            rgbs = TF.normalize(
                rgbs,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )
        (h, w), ratio = _shapes((H, W), self.image_shape)
        pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
        rgbs, gt_intrinsics = _preprocess(
            rgbs,
            intrinsics,
            (h, w),
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            self.image_shape,
        )

        # run encoder
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]

        # get data for decoder and adapt to given camera
        inputs = {}
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens
        inputs["image"] = rgbs
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True
            self.pixel_decoder.skip_camera = skip_camera
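            # Both flags are reset after decoding below; skip_camera bypasses
            # the camera head when the provided intrinsics are trusted.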

        # decode all
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})

        # undo the reshaping and get original image size (slow)
        predictions, pred_intrinsics = _postprocess(
            predictions,
            pred_intrinsics,
            self.image_shape,
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            (H, W),
        )

        # final 3D points backprojection
        intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
        angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
        points_3d = torch.cat((angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # output data
        outputs = {
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions[:, -1:],
        }
        self.pixel_decoder.test_fixed_camera = False
        self.pixel_decoder.skip_camera = False
        return outputs

    def load_pretrained(self, model_file):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        dict_model = torch.load(model_file, map_location=device)
        if "model" in dict_model:
            dict_model = dict_model["model"]
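        # strip the "module." prefix left behind by DistributedDataParallel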
        new_state_dict = deepcopy(
            {k.replace("module.", ""): v for k, v in dict_model.items()}
        )
        info = self.load_state_dict(new_state_dict, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    def get_params(self, config):
        if hasattr(self.pixel_encoder, "get_params"):
            encoder_p, encoder_lr = self.pixel_encoder.get_params(
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
                config["training"]["ld"],
            )
        else:
            encoder_p, encoder_lr = get_params(
                self.pixel_encoder,
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
            )
        decoder_p, decoder_lr = get_params(
            self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
        )
        return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]

    @property
    def device(self):
        return next(self.parameters()).device

    def build(self, config):
        mod = importlib.import_module("unidepth.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["data"],
            **config["model"]["pixel_encoder"],
            "interpolate_offset": 0.1,
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)

        config["model"]["pixel_encoder"]["patch_size"] = (
            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
        )
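        # ViT-style encoders expose per-stage embed_dims directly; hierarchical
        # backbones (e.g. ConvNeXt) derive the four stage widths by doubling
        # the base embed_dim.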
        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths

        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = Decoder(config)
        self.image_shape = config["data"]["image_shape"]
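

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name is an assumption based on the
    # weights published on the Hugging Face Hub; swap in your own if it differs.
    model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
    model = model.eval()

    # A random uint8 frame stands in for a real (C, H, W) RGB image; infer()
    # handles the float conversion and ImageNet normalization internally.
    rgb = torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8)

    outputs = model.infer(rgb)  # intrinsics are predicted when not provided
    print(outputs["depth"].shape)       # (1, 1, 480, 640) metric depth
    print(outputs["intrinsics"].shape)  # (1, 3, 3) estimated pinhole camera
    print(outputs["points"].shape)      # (1, 3, 480, 640) 3D points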