diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0a51e5ebcd8f129c76f43225cd770a2344101cc7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,135 @@
+
+
Map It Anywhere (MIA): Empowering Bird’s Eye View Mapping using Large-scale Public Data
+
+
+ Cherie Ho*
+ ·
+ Jiaye (Tony) Zou*
+ ·
+ Omar Alama*
+
+ Sai Mitheran Jagadesh Kumar
+ ·
+ Benjamin Chiang
+ ·
+ Taneesh Gupta
+ ·
+ Chen Wang
+
+ Nikhil Keetha
+ ·
+ Katia Sycara
+ ·
+ Sebastian Scherer
+
+
+
+
+
+")
+
+## Table of Contents
+ - [Using the MIA Data Engine](#using-the-mia-data-engine)
+ - [Downloading the MIA dataset](#downloading-the-mia-dataset)
+ - [Training](#training)
+ - [Evaluation](#evaluation)
+ - [Acknowledgement](#acknowledgement)
+
+
+## Using the MIA data engine
+
+### 0. Setting up the environment
+0. Install docker by following the instructions on their [website](https://www.docker.com/get-started/)
+1. Build the docker image `mia/Dockerfile` by running:
+
+ docker build -t mia:release mia/Dockerfile
+2. Launch the container while mounting this repository to the container file system.
+
+ docker run -v :/home/MapItAnywhere --network=bridge -it mia:release
+
+### 1. Getting FPVs
+
+The first stage of the MIA data engine is to get the first person images.
+First, if you want to pull your own locations, copy the example configuration from `mia/conf/example.yaml` and edit the cities list to specify the cities you want. Feel free to explore the other well-documented FPV options in the configuration file.
+
+Once configuration is done simply run the following from inside your docker container with working dir set to this repo:
+
+ python3.9 -m mia.fpv.get_fpv --cfg mia/conf/.yaml
+
+That's it ! The engine will now automatically fetch, filter, and process your FPV images. You may get a few errors specifying that some images were unable to be fetched due to permission limitations. That is normal and the engine will continue.
+
+Once all your locations have been downloaded, you will see that parquet files, images, and raw_images, have been populated in your `dataset_dir` for each location. You can now move on to getting BEVs.
+
+### 2. Getting BEVs
+Once you have the FPV parquet dataframes downloaded, you are now ready to fetch and generate the BEV smenatic maps.
+
+Edit the documented bev options in your configuration file to suit your use case. The defaults are tuned to what we used to produce the MIA datasets and you can use them as is.
+
+Once configuration is done simply run the following from inside your docker container with working dir set to this repo:
+
+ python3.9 -m mia.bev.get_bev
+
+The data engine will now fetch, process, and save the semantic masks.
+
+You now have FPV-BEV pairs with associated metadata and camera parameters !
+
+**Note** to get satellite imagery for comparison you must first download it by toggling the store_sat option in the configuration
+
+### 3. (Optional) Visualize your data
+You can visualize a few samples using the tool `mia/misc_tools/vis_samples.py`.
+
+From inside the container with working dir set to this repo, run:
+
+ python3.9 -m mia/misc_tools/vis_samples --dataset_dir /home/mia_dataset_release --locations
+
+If successful, the script will generate a PDF called `compare.pdf` in the pittsburgh directory. Upon openning you should see the metadata, FPVs, and BEVs of a few samples of the dataset.
+
+
+## Downloading the MIA dataset
+Refer to [mia/dataset.md](mia/dataset.md) for instructions.
+
+## Training
+
+### Pre-train with MIA Dataset
+To pretrain using our paper configuration simply run:
+
+ python -m mapper.mapper data.split= data.data_dir=
+
+### Finetune with NuScenes Dataset
+To finetune using NuScenes Dataset with our paper configuration, run:
+
+ python -m mapper.mapper -cn mapper_nuscenes training.checkpoint= data.data_dir= data.map_dir=
+
+## Reproduction
+#### Dataset Setup
+**MIA**: Follow download instructions in [Downloading the MIA Dataset](#downloading-the-mia-dataset)
+
+**NuScenes**: Follow the data generation instructions in [Mono-Semantic-Maps](https://github.com/tom-roddick/mono-semantic-maps?tab=readme-ov-file#nuscenes). To match the newest available information, we use v1.3 of the NuScenes' map expansion pack.
+
+**KITTI360-BEV**: Follow the KITTI360-BEV dataset instructions in [SkyEye](https://github.com/robot-learning-freiburg/SkyEye?tab=readme-ov-file#skyeye-datasets)
+
+#### Inference
+To generate MIA dataset prediction results(on test split), use:
+
+ python -m mapper.mapper data.split= data.data_dir= training.checkpoint= training.eval=true
+*To specify location, add `data.scenes` in the argument. For example, for held-out cities `data.scenes="[pittsburgh, houston]"`*
+
+To Generate NuScenes dataset prediction results(on validation split), use:
+
+ python -m mapper.mapper -cn mapper_nuscenes training.checkpoint= data.data_dir= data.map_dir= training.eval=true
+
+To Generate KITTI360-BEV dataset prediction results (on validation split), use:
+
+ python -m mapper.mapper -cn mapper_kitti training.checkpoint= data.seam_root_dir= data.dataset_root_dir= training.eval=true
+
+
+## License
+[More Information Needed]
+
+## Acknowledgement
+We thank the authors of the following repositories for their open-source code:
+- [OrienterNet](https://github.com/facebookresearch/OrienterNet)
+- [Map Machine](https://github.com/enzet/map-machine)
+- [Mono-Semantic-Maps](https://github.com/tom-roddick/mono-semantic-maps)
+- [Translating Images Into Maps](https://github.com/avishkarsaha/translating-images-into-maps)
+- [SkyEye](https://github.com/robot-learning-freiburg/SkyEye)
\ No newline at end of file
diff --git a/mapper/__init__.py b/mapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..618c21f0657b17e6f0a21c6a23c47c1f3e19e0ba
--- /dev/null
+++ b/mapper/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import os, sys
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+from pathlib import Path
+import logging
+
+import pytorch_lightning # noqa: F401
+
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger("mapper")
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
+
+pl_logger = logging.getLogger("pytorch_lightning")
+if len(pl_logger.handlers):
+ pl_logger.handlers[0].setFormatter(formatter)
+
+repo_dir = Path(__file__).parent.parent
+EXPERIMENTS_PATH = repo_dir / "experiments/"
+DATASETS_PATH = repo_dir / "datasets/"
diff --git a/mapper/callbacks.py b/mapper/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b60720711b023ef57b2ed23f8b237d16a319a6
--- /dev/null
+++ b/mapper/callbacks.py
@@ -0,0 +1,105 @@
+import torch
+import pytorch_lightning as pl
+from pathlib import Path
+from typing import Any
+import torchvision
+import wandb
+
+
+class EvalSaveCallback(pl.Callback):
+
+ def __init__(self, save_dir: Path) -> None:
+ super().__init__()
+ self.save_dir = save_dir
+
+ def save(self, outputs, batch, batch_idx):
+ name = batch['name']
+
+ filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
+ torch.save({
+ "fpv": batch['image'],
+ "seg_masks": batch['seg_masks'],
+ 'name': name,
+ "output": outputs["output"],
+ "valid_bev": outputs["valid_bev"],
+ }, filename)
+
+ def on_test_batch_end(self, trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: torch.Tensor | Any | None,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0) -> None:
+ if not outputs:
+ return
+
+ self.save(outputs, batch, batch_idx)
+
+ def on_validation_batch_end(self, trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: torch.Tensor | Any | None,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0) -> None:
+ if not outputs:
+
+ return
+
+ self.save(outputs, batch, batch_idx)
+
+
+class ImageLoggerCallback(pl.Callback):
+ def __init__(self, num_classes):
+ super().__init__()
+ self.num_classes = num_classes
+
+ def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
+ fpv_rgb = batch["image"]
+ fpv_grid = torchvision.utils.make_grid(
+ fpv_rgb, nrow=8, normalize=False)
+ images = [
+ wandb.Image(fpv_grid, caption="fpv")
+ ]
+
+ pred = outputs['output'].permute(0, 2, 3, 1)
+ pred[outputs["valid_bev"][..., :-1] == 0] = 0
+ pred = (pred > 0.5).float()
+ pred = pred.permute(0, 3, 1, 2)
+
+ for i in range(self.num_classes):
+ gt_class_i = batch['seg_masks'][..., i]
+ gt_class_i_grid = torchvision.utils.make_grid(
+ gt_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)
+ pred_class_i = pred[:, i]
+ pred_class_i_grid = torchvision.utils.make_grid(
+ pred_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)
+
+ images += [
+ wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
+ wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}")
+ ]
+
+ trainer.logger.experiment.log(
+ {
+ "{}/images".format(mode): images
+ }
+ )
+
+ def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
+ if batch_idx == 0:
+ with torch.no_grad():
+ outputs = pl_module(batch)
+ self.log_image(trainer, pl_module, outputs,
+ batch, batch_idx, mode="val")
+
+ def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
+ if batch_idx == 0:
+ pl_module.eval()
+
+ with torch.no_grad():
+ outputs = pl_module(batch)
+
+ self.log_image(trainer, pl_module, outputs,
+ batch, batch_idx, mode="train")
+
+ pl_module.train()
diff --git a/mapper/conf/data/kitti.yaml b/mapper/conf/data/kitti.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5f6b42dfbef12684e7ec0a687baed00dc584d58
--- /dev/null
+++ b/mapper/conf/data/kitti.yaml
@@ -0,0 +1,40 @@
+name: kitti
+seam_root_dir: /path/to/generated/seam
+dataset_root_dir: /path/to/kitti/dataset
+bev_percentage: 100
+pixel_per_meter: 2
+crop_size_meters: 50
+target_focal_length: 256
+resize_image: null
+pad_to_multiple: 14
+num_classes: 8
+loading:
+ train:
+ batch_size: 32
+ num_workers: 32
+ val:
+ batch_size: 32
+ num_workers: 32
+ test:
+ batch_size: 32
+ num_workers: 32
+pad_to_square: true
+rectify_pitch: true
+gravity_align: false
+class_mapping: [0, 0, 1, 2, 0, 3]
+augmentations:
+ enabled: True
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.5
+ random_flip: 0.5
+ hue: 0.5
+ random_resized_crop: False
+ gaussian_noise:
+ enabled: False
+ mean: 0.0
+ std: 0.1
+ brightness_contrast:
+ enabled: True
+ brightness_factor: 0.2
+ contrast_factor: 0.2
\ No newline at end of file
diff --git a/mapper/conf/data/mia.yaml b/mapper/conf/data/mia.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b4bd88038f1d933940627f5ab02d5331413cc0e
--- /dev/null
+++ b/mapper/conf/data/mia.yaml
@@ -0,0 +1,44 @@
+name: mapillary
+scenes:
+- chicago
+- new_york
+- los_angeles
+- san_francisco
+split: /path/to/split/file
+data_dir: /path/to/mia/dataset
+loading:
+ train:
+ batch_size: 128
+ num_workers: 30
+ val:
+ batch_size: 128
+ num_workers: 30
+ test:
+ batch_size: 1
+ num_workers: 0
+ testsmall:
+ batch_size: 1
+ num_workers: 0
+num_classes: 6
+pixel_per_meter: 2
+crop_size_meters: 64
+resize_image: 512
+pad_to_square: true
+rectify_pitch: true
+gravity_align: true
+augmentations:
+ enabled: True
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.5
+ random_flip: 0.5
+ hue: 0.5
+ random_resized_crop: False
+ gaussian_noise:
+ enabled: False
+ mean: 0.0
+ std: 0.1
+ brightness_contrast:
+ enabled: True
+ brightness_factor: 0.2
+ contrast_factor: 0.2
\ No newline at end of file
diff --git a/mapper/conf/data/nuscenes.yaml b/mapper/conf/data/nuscenes.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d002947f587232680507d69f737782ec6bc21d5
--- /dev/null
+++ b/mapper/conf/data/nuscenes.yaml
@@ -0,0 +1,38 @@
+name: nuscenes
+data_dir: /path/to/nuscenes/data
+map_dir: /path/to/generated/maps
+version: v1.0-trainval
+pixel_per_meter: 2
+crop_size_meters: 50
+resize_image: 512
+percentage: 1.0
+class_mapping: [0, 1, 2, 0, 0, 3]
+num_classes: 14
+loading:
+ train:
+ batch_size: 128
+ num_workers: 10
+ val:
+ batch_size: 128
+ num_workers: 10
+ test:
+ batch_size: 128
+ num_workers: 10
+pad_to_square: true
+rectify_pitch: true
+gravity_align: true
+augmentations:
+ enabled: True
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.5
+ hue: 0.5
+ random_resized_crop: False
+ gaussian_noise:
+ enabled: False
+ mean: 0.0
+ std: 0.1
+ brightness_contrast:
+ enabled: True
+ brightness_factor: 0.2
+ contrast_factor: 0.2
\ No newline at end of file
diff --git a/mapper/conf/mapper_kitti.yaml b/mapper/conf/mapper_kitti.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..016e4196521378cee81957af82c47b2e0bf8e989
--- /dev/null
+++ b/mapper/conf/mapper_kitti.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - schema/data: kitti
+ - data: kitti
+ - model: mapper
+ - training
+ - _self_
+
+experiment:
+ name: MIA_DINOv2_Mapper_KITTI
+
+model:
+ loss:
+ xent_weight: 1.0
+ dice_weight: 1.0
+ focal_loss: false
+ focal_loss_gamma: 2.0
+ requires_frustrum: true
+ requires_flood_mask: true
+ class_weights: null
+ label_smoothing: 0.1
+
+training:
+ checkpoint: /path/to/checkpoint
\ No newline at end of file
diff --git a/mapper/conf/mapper_nuscenes.yaml b/mapper/conf/mapper_nuscenes.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8b036c1299b4ab472eb1ce11482cc04c909b3833
--- /dev/null
+++ b/mapper/conf/mapper_nuscenes.yaml
@@ -0,0 +1,26 @@
+defaults:
+ - schema/data: nuscenes
+ - data: nuscenes
+ - model: mapper
+ - training
+ - _self_
+
+experiment:
+ name: MIA_DINOv2_Mapper_NuScenes
+
+model:
+ loss:
+ xent_weight: 1.0
+ dice_weight: 1.0
+ focal_loss: false
+ focal_loss_gamma: 2.0
+ class_weights: [1.00060036, 1.85908161, 1.0249052, 0., 0., 2.57267816]
+ requires_frustrum: true
+ label_smoothing: 0.1
+
+training:
+ checkpoint: /path/to/checkpoint
+ finetune: true
+ lr: 0.0001
+ trainer:
+ max_epochs: 50
\ No newline at end of file
diff --git a/mapper/conf/model/image_encoder/dino.yaml b/mapper/conf/model/image_encoder/dino.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..85f60eeafec39b63139888678638284a9c412f8c
--- /dev/null
+++ b/mapper/conf/model/image_encoder/dino.yaml
@@ -0,0 +1,5 @@
+name: feature_extractor_DPT
+backbone:
+ pretrained: true
+ frozen: true
+ output_dim: ${model.latent_dim} # Match Latent Dimension
\ No newline at end of file
diff --git a/mapper/conf/model/image_encoder/resnet.yaml b/mapper/conf/model/image_encoder/resnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4faebe95fcc7dd25b802484df70a96bbc8f2bdfa
--- /dev/null
+++ b/mapper/conf/model/image_encoder/resnet.yaml
@@ -0,0 +1,12 @@
+name: feature_extractor_resnet
+backbone:
+ pretrained: true
+ frozen: true
+ output_dim: ${model.latent_dim} # Match Latent Dimension
+ input_dim: 3
+ encoder: resnet50
+ num_downsample: null
+ remove_stride_from_first_conv: false
+ decoder_norm: "nn.BatchNorm2d"
+ do_average_pooling: false
+ checkpointed: false
\ No newline at end of file
diff --git a/mapper/conf/model/mapper.yaml b/mapper/conf/model/mapper.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..50cb6e8ac2ae492c1f8751018270315797c7f9ec
--- /dev/null
+++ b/mapper/conf/model/mapper.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - schema/backbone: dino
+ - image_encoder: dino
+
+segmentation_head:
+ dropout_rate: 0.2
+name: map_perception_net
+num_classes: 6
+latent_dim: 128
+z_max: 50
+x_max: 25
+pixel_per_meter: ${data.pixel_per_meter}
+num_scale_bins: 32
+loss:
+ num_classes: ${..num_classes}
\ No newline at end of file
diff --git a/mapper/conf/pretrain.yaml b/mapper/conf/pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cb4b3180bbd88dad062f36bcf70d0795bcc572aa
--- /dev/null
+++ b/mapper/conf/pretrain.yaml
@@ -0,0 +1,24 @@
+defaults:
+ - schema/data: mia
+ - data: mia
+ - model: mapper
+ - training
+ - _self_
+
+experiment:
+ name: MIA_DINOv2_Pretrain
+
+model:
+ loss:
+ xent_weight: 1.0
+ dice_weight: 1.0
+ focal_loss: false
+ focal_loss_gamma: 2.0
+ requires_frustrum: true
+ class_weights: [ 1.00351229, 4.34782609, 1.00110121, 1.03124678,
+ 6.69792364, 7.55857899 ]
+ label_smoothing: 0.1
+
+training:
+ trainer:
+ max_epochs: 15
\ No newline at end of file
diff --git a/mapper/conf/pretrain_resnet.yaml b/mapper/conf/pretrain_resnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..957f2ef552c1a2fe69e958183e97df15ee66d3fb
--- /dev/null
+++ b/mapper/conf/pretrain_resnet.yaml
@@ -0,0 +1,26 @@
+defaults:
+ - schema/data: mia
+ - data: mia
+ - model: mapper
+ - training
+ - _self_
+ - override model/schema/backbone: resnet
+ - override model/image_encoder: resnet
+
+experiment:
+ name: MIA_DINOv2_Pretrain
+
+model:
+ loss:
+ xent_weight: 1.0
+ dice_weight: 1.0
+ focal_loss: false
+ focal_loss_gamma: 2.0
+ requires_frustrum: true
+ class_weights: [ 1.00351229, 4.34782609, 1.00110121, 1.03124678,
+ 6.69792364, 7.55857899 ]
+
+training:
+ trainer:
+ max_steps: 10
+ max_epochs: 15
\ No newline at end of file
diff --git a/mapper/conf/training.yaml b/mapper/conf/training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e4a6d52f8dab9974b5f43abfe621fc9229a4c778
--- /dev/null
+++ b/mapper/conf/training.yaml
@@ -0,0 +1,30 @@
+experiment:
+ name: MGL_DINOv2_v4-baseline-less-class
+ seed: 42
+training:
+ num_classes: ${model.num_classes}
+ lr: 0.001
+ lr_scheduler:
+ name: "CosineAnnealingLR"
+ args:
+ T_max: $total_epochs
+ eta_min: 0.0000001
+ checkpoint: null
+ finetune: false
+ eval: false
+ save_dir: eval_results
+ trainer:
+ # val_check_interval: 250
+ # log_every_n_steps: 100
+ # limit_val_batches: 0
+ # max_steps: 500000
+ # num_epochs: 15
+ precision: bf16-mixed
+ accelerator: gpu
+ strategy: ddp_find_unused_parameters_true
+ checkpointing:
+ dirpath: checkpoints/
+ monitor: val/total/loss
+ save_top_k: -1
+ mode: min
+ save_last: True
\ No newline at end of file
diff --git a/mapper/data/__init__.py b/mapper/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d0981f3f2386dc647cbcc644ab9395b733310b6
--- /dev/null
+++ b/mapper/data/__init__.py
@@ -0,0 +1,7 @@
+from .mapillary.data_module import MapillaryDataModule
+from .nuscenes.data_module import NuScenesData
+
+modules = {
+ "mapillary": MapillaryDataModule,
+ "nuscenes": NuScenesData
+}
diff --git a/mapper/data/base.py b/mapper/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85eff918825e27a662db2eecfd335fd9db81583
--- /dev/null
+++ b/mapper/data/base.py
@@ -0,0 +1,19 @@
+from abc import abstractmethod
+from typing import Optional
+
+
+class DataBase():
+ def __init__(self) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def prepare_data(self) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def setup(self, stage: Optional[str] = None):
+ raise NotImplementedError
+
+ @abstractmethod
+ def dataset(self, stage: str):
+ raise NotImplementedError
\ No newline at end of file
diff --git a/mapper/data/image.py b/mapper/data/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8b5080cbcd316df4837db23f93ba8219ddda415
--- /dev/null
+++ b/mapper/data/image.py
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Callable, Optional, Union, Sequence
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+import collections
+from scipy.spatial.transform import Rotation
+
+from ..utils.geometry import from_homogeneous, to_homogeneous
+from ..utils.wrappers import Camera
+
+
+def rectify_image(
+ image: torch.Tensor,
+ cam: Camera,
+ roll: float,
+ pitch: Optional[float] = None,
+ valid: Optional[torch.Tensor] = None,
+):
+ *_, h, w = image.shape
+ grid = torch.meshgrid(
+ [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
+ indexing="xy",
+ )
+ grid = torch.stack(grid, -1).to(image.dtype)
+
+ if pitch is not None:
+ args = ("ZX", (roll, pitch))
+ else:
+ args = ("Z", roll)
+ R = Rotation.from_euler(*args, degrees=True).as_matrix()
+ R = torch.from_numpy(R).to(image)
+
+ grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
+ grid_rect = cam.denormalize(from_homogeneous(grid_rect))
+ grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
+ rectified = torch.nn.functional.grid_sample(
+ image[None],
+ grid_norm[None],
+ align_corners=False,
+ mode="bilinear",
+ ).squeeze(0)
+ if valid is None:
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
+ else:
+ valid = (
+ torch.nn.functional.grid_sample(
+ valid[None, None].float(),
+ grid_norm[None],
+ align_corners=False,
+ mode="nearest",
+ )[0, 0]
+ > 0
+ )
+ return rectified, valid
+
+
+def resize_image(
+ image: torch.Tensor,
+ size: Union[int, Sequence, np.ndarray],
+ fn: Optional[Callable] = None,
+ camera: Optional[Camera] = None,
+ valid: np.ndarray = None,
+):
+ """Resize an image to a fixed size, or according to max or min edge."""
+ *_, h, w = image.shape
+ if fn is not None:
+ assert isinstance(size, int)
+ scale = size / fn(h, w)
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
+ scale = (scale, scale)
+ else:
+ if isinstance(size, (collections.abc.Sequence, np.ndarray)):
+ w_new, h_new = size
+ elif isinstance(size, int):
+ w_new = h_new = size
+ else:
+ raise ValueError(f"Incorrect new size: {size}")
+ scale = (w_new / w, h_new / h)
+ if (w, h) != (w_new, h_new):
+ mode = tvf.InterpolationMode.BILINEAR
+ image = tvf.resize(image, (int(h_new), int(w_new)), interpolation=mode, antialias=True)
+ image.clip_(0, 1)
+ if camera is not None:
+ camera = camera.scale(scale)
+ if valid is not None:
+ valid = tvf.resize(
+ valid.unsqueeze(0),
+ (int(h_new), int(w_new)),
+ interpolation=tvf.InterpolationMode.NEAREST,
+ ).squeeze(0)
+ ret = [image, scale]
+ if camera is not None:
+ ret.append(camera)
+ if valid is not None:
+ ret.append(valid)
+ return ret
+
+
+def pad_image(
+ image: torch.Tensor,
+ size: Union[int, Sequence, np.ndarray],
+ camera: Optional[Camera] = None,
+ valid: torch.Tensor = None,
+ crop_and_center: bool = False,
+):
+ if isinstance(size, int):
+ w_new = h_new = size
+ elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
+ w_new, h_new = size
+ else:
+ raise ValueError(f"Incorrect new size: {size}")
+ *c, h, w = image.shape
+ if crop_and_center:
+ diff = np.array([w - w_new, h - h_new])
+ left, top = left_top = np.round(diff / 2).astype(int)
+ right, bottom = diff - left_top
+ else:
+ assert h <= h_new
+ assert w <= w_new
+ top = bottom = left = right = 0
+ slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
+ slice_in = np.s_[
+ ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
+ ]
+ if (w, h) == (w_new, h_new):
+ out = image
+ else:
+ out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
+ out[slice_out] = image[slice_in]
+ if camera is not None:
+ camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
+ out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
+ out_valid[slice_out] = True if valid is None else valid[slice_in]
+ if camera is not None:
+ return out, out_valid, camera
+ else:
+ return out, out_valid
diff --git a/mapper/data/kitti/data_module.py b/mapper/data/kitti/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..39d1ac02cf17b9baabd5dbf182b632ab2081866a
--- /dev/null
+++ b/mapper/data/kitti/data_module.py
@@ -0,0 +1,32 @@
+from ..base import DataBase
+from .dataset import BEVKitti360Dataset
+from ..schema import KITTIDataConfiguration
+
+class BEVKitti360Data(DataBase):
+ def __init__(self, cfg: KITTIDataConfiguration) -> None:
+ self.cfg = cfg
+ self._dataset = {}
+
+ def prepare_data(self) -> None:
+ return
+
+ def setup(self, stage: str) -> None:
+ split = {
+ 'fit': 'train',
+ 'val': 'val',
+ 'validate': 'val',
+ 'test': 'val',
+ "train": "train"
+ }[stage]
+
+ self._dataset[stage] = BEVKitti360Dataset(
+ cfg=self.cfg,
+ split_name=split
+ )
+
+ def dataset(self, stage: str):
+ if self._dataset.get(stage) is None:
+ self.setup(stage)
+
+ return self._dataset[stage]
+
\ No newline at end of file
diff --git a/mapper/data/kitti/dataset.py b/mapper/data/kitti/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b4af972de17ae8382a38fcf36b86da34550f1e1
--- /dev/null
+++ b/mapper/data/kitti/dataset.py
@@ -0,0 +1,317 @@
+import os
+import numpy as np
+import torch.utils.data as data
+import umsgpack
+from PIL import Image
+import json
+import torchvision.transforms as tvf
+
+from .transform import BEVTransform
+from ..schema import KITTIDataConfiguration
+
+class BEVKitti360Dataset(data.Dataset):
+ _IMG_DIR = "img"
+ _BEV_MSK_DIR = "bev_msk"
+ _BEV_PLABEL_DIR = "bev_plabel_dynamic"
+ _FV_MSK_DIR = "front_msk_seam"
+ _BEV_DIR = "bev_ortho"
+ _LST_DIR = "split"
+ _PERCENTAGES_DIR = "percentages"
+ _BEV_METADATA_FILE = "metadata_ortho.bin"
+ _FV_METADATA_FILE = "metadata_front.bin"
+
+ def __init__(self, cfg: KITTIDataConfiguration, split_name="train"):
+ super(BEVKitti360Dataset, self).__init__()
+ self.cfg = cfg
+ self.seam_root_dir = cfg.seam_root_dir # Directory of seamless data
+ self.kitti_root_dir = cfg.dataset_root_dir # Directory of the KITTI360 data
+ self.split_name = split_name
+
+ self.rgb_cameras = ['front']
+ if cfg.bev_percentage < 1:
+ self.bev_percentage = cfg.bev_percentage
+ else:
+ self.bev_percentage = int(cfg.bev_percentage)
+
+ # Folders
+ self._img_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._IMG_DIR)
+ self._bev_msk_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_MSK_DIR, BEVKitti360Dataset._BEV_DIR)
+ self._bev_plabel_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_PLABEL_DIR, BEVKitti360Dataset._BEV_DIR)
+ self._fv_msk_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._FV_MSK_DIR, "front")
+ self._lst_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._LST_DIR)
+ self._percentages_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._LST_DIR, BEVKitti360Dataset._PERCENTAGES_DIR)
+
+ # Load meta-data and split
+ self._bev_meta, self._bev_images, self._bev_images_all, self._fv_meta, self._fv_images, self._fv_images_all,\
+ self._img_map, self.bev_percent_split = self._load_split()
+
+ self.tfs = self.get_augmentations() if split_name == "train" else tvf.Compose([])
+ self.transform = BEVTransform(cfg, self.tfs)
+
+ def get_augmentations(self):
+
+ print(f"Augmentation!", "\n" * 10)
+ augmentations = [
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness,
+ contrast=self.cfg.augmentations.contrast,
+ saturation=self.cfg.augmentations.saturation,
+ hue=self.cfg.augmentations.hue,
+ )
+ ]
+
+ if self.cfg.augmentations.random_resized_crop:
+ augmentations.append(
+ tvf.RandomResizedCrop(scale=(0.8, 1.0))
+ ) # RandomResizedCrop
+
+ if self.cfg.augmentations.gaussian_noise.enabled:
+ augmentations.append(
+ tvf.GaussianNoise(
+ mean=self.cfg.augmentations.gaussian_noise.mean,
+ std=self.cfg.augmentations.gaussian_noise.std,
+ )
+ ) # Gaussian noise
+
+ if self.cfg.augmentations.brightness_contrast.enabled:
+ augmentations.append(
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
+ contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
+ saturation=0, # Keep saturation at 0 for brightness and contrast adjustment
+ hue=0,
+ )
+ ) # Brightness and contrast adjustment
+
+ return tvf.Compose(augmentations)
+
+ # Load the train or the validation split
+ def _load_split(self):
+ with open(os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_METADATA_FILE), "rb") as fid:
+ bev_metadata = umsgpack.unpack(fid, encoding="utf-8")
+
+ with open(os.path.join(self.seam_root_dir, BEVKitti360Dataset._FV_METADATA_FILE), 'rb') as fid:
+ fv_metadata = umsgpack.unpack(fid, encoding="utf-8")
+
+ # Read the files for this split
+ with open(os.path.join(self._lst_dir, self.split_name + ".txt"), "r") as fid:
+ lst = fid.readlines()
+ lst = [line.strip() for line in lst]
+
+ if self.split_name == "train":
+ # Get all the frames in the train dataset. This will be used for generating samples for temporal consistency.
+ with open(os.path.join(self._lst_dir, "{}_all.txt".format(self.split_name)), 'r') as fid:
+ lst_all = fid.readlines()
+ lst_all = [line.strip() for line in lst_all]
+
+ # Get all the samples for which the BEV plabels have to be loaded.
+ percentage_file = os.path.join(self._percentages_dir, "{}_{}.txt".format(self.split_name, self.bev_percentage))
+ print("Loading {}% file".format(self.bev_percentage))
+ with open(percentage_file, 'r') as fid:
+ lst_percent = fid.readlines()
+ lst_percent = [line.strip() for line in lst_percent]
+ else:
+ lst_all = lst
+ lst_percent = lst
+
+ # Remove elements from lst if they are not in _FRONT_MSK_DIR
+ fv_msk_frames = os.listdir(self._fv_msk_dir)
+ fv_msk_frames = [frame.split(".")[0] for frame in fv_msk_frames]
+ fv_msk_frames_exist_map = {entry: True for entry in fv_msk_frames} # This is to speed-up the dataloader
+ lst = [entry for entry in lst if entry in fv_msk_frames_exist_map]
+ lst_all = [entry for entry in lst_all if entry in fv_msk_frames_exist_map]
+
+ # Filter based on the samples plabels
+ if self.bev_percentage < 100:
+ lst_filt = [entry for entry in lst if entry in lst_percent]
+ lst = lst_filt
+
+ # Remove any potential duplicates
+ lst = set(lst)
+ lst_percent = set(lst_percent)
+
+ img_map = {}
+ for camera in self.rgb_cameras:
+ with open(os.path.join(self._img_dir, "{}.json".format(camera))) as fp:
+ map_list = json.load(fp)
+ map_dict = {k: v for d in map_list for k, v in d.items()}
+ img_map[camera] = map_dict
+
+ bev_meta = bev_metadata["meta"]
+ bev_images = [img_desc for img_desc in bev_metadata["images"] if img_desc["id"] in lst]
+ fv_meta = fv_metadata["meta"]
+ fv_images = [img_desc for img_desc in fv_metadata['images'] if img_desc['id'] in lst]
+
+ # Check for inconsistency due to inconsistencies in the input files or dataset
+ bev_images_ids = [bev_img["id"] for bev_img in bev_images]
+ fv_images_ids = [fv_img["id"] for fv_img in fv_images]
+ assert set(bev_images_ids) == set(fv_images_ids) and len(bev_images_ids) == len(fv_images_ids), 'Inconsistency between fv_images and bev_images detected'
+
+ if lst_all is not None:
+ bev_images_all = [img_desc for img_desc in bev_metadata['images'] if img_desc['id'] in lst_all]
+ fv_images_all = [img_desc for img_desc in fv_metadata['images'] if img_desc['id'] in lst_all]
+ else:
+ bev_images_all, fv_images_all = None, None
+
+ return bev_meta, bev_images, bev_images_all, fv_meta, fv_images, fv_images_all, img_map, lst_percent
+
+ def _find_index(self, list, key, value):
+ for i, dic in enumerate(list):
+ if dic[key] == value:
+ return i
+ return None
+
+ def _load_item(self, item_idx):
+ # Find the index of the element in the list containing all elements
+ all_idx = self._find_index(self._fv_images_all, "id", self._fv_images[item_idx]['id'])
+ if all_idx is None:
+ raise IOError("Required index not found!")
+
+ bev_img_desc = self._bev_images[item_idx]
+ fv_img_desc = self._fv_images[item_idx]
+
+ scene, frame_id = self._bev_images[item_idx]["id"].split(";")
+
+ # Get the RGB file names
+ img_file = os.path.join(
+ self.kitti_root_dir,
+ self._img_map["front"]["{}.png"
+ .format(bev_img_desc['id'])]
+ )
+
+ if not os.path.exists(img_file):
+ raise IOError(
+ "RGB image not found! Scene: {}, Frame: {}".format(scene, frame_id)
+ )
+
+ # Load the images
+ img = Image.open(img_file).convert(mode="RGB")
+
+ # Load the BEV mask
+ bev_msk_file = os.path.join(
+ self._bev_msk_dir,
+ "{}.png".format(bev_img_desc['id'])
+ )
+ bev_msk = Image.open(bev_msk_file)
+ bev_plabel = None
+
+ # Load the front mask
+ fv_msk_file = os.path.join(
+ self._fv_msk_dir,
+ "{}.png".format(fv_img_desc['id'])
+ )
+ fv_msk = Image.open(fv_msk_file)
+
+
+ bev_weights_msk_combined = None
+
+ # Get the other information
+ bev_cat = bev_img_desc["cat"]
+ bev_iscrowd = bev_img_desc["iscrowd"]
+ fv_cat = fv_img_desc['cat']
+ fv_iscrowd = fv_img_desc['iscrowd']
+ fv_intrinsics = fv_img_desc["cam_intrinsic"]
+ ego_pose = fv_img_desc['ego_pose'] # This loads the cam0 pose
+
+ # Get the ids of all the frames
+ frame_ids = bev_img_desc["id"]
+
+ return img, bev_msk, bev_plabel, fv_msk, bev_weights_msk_combined, bev_cat, \
+ bev_iscrowd, fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, frame_ids
+
+ @property
+ def fv_categories(self):
+ """Category names"""
+ return self._fv_meta["categories"]
+
+ @property
+ def fv_num_categories(self):
+ """Number of categories"""
+ return len(self.fv_categories)
+
+ @property
+ def fv_num_stuff(self):
+ """Number of "stuff" categories"""
+ return self._fv_meta["num_stuff"]
+
+ @property
+ def fv_num_thing(self):
+ """Number of "thing" categories"""
+ return self.fv_num_categories - self.fv_num_stuff
+
+ @property
+ def bev_categories(self):
+ """Category names"""
+ return self._bev_meta["categories"]
+
+ @property
+ def bev_num_categories(self):
+ """Number of categories"""
+ return len(self.bev_categories)
+
+ @property
+ def bev_num_stuff(self):
+ """Number of "stuff" categories"""
+ return self._bev_meta["num_stuff"]
+
+ @property
+ def bev_num_thing(self):
+ """Number of "thing" categories"""
+ return self.bev_num_categories - self.bev_num_stuff
+
+ @property
+ def original_ids(self):
+ """Original class id of each category"""
+ return self._fv_meta["original_ids"]
+
+ @property
+ def palette(self):
+ """Default palette to be used when color-coding semantic labels"""
+ return np.array(self._fv_meta["palette"], dtype=np.uint8)
+
+ @property
+ def img_sizes(self):
+ """Size of each image of the dataset"""
+ return [img_desc["size"] for img_desc in self._fv_images]
+
+ @property
+ def img_categories(self):
+ """Categories present in each image of the dataset"""
+ return [img_desc["cat"] for img_desc in self._fv_images]
+
+ @property
+ def dataset_name(self):
+ return "Kitti360"
+
+ def __len__(self):
+ if self.cfg.percentage < 1:
+ return int(len(self._fv_images) * self.cfg.percentage)
+
+ return len(self._fv_images)
+
+ def __getitem__(self, item):
+ img, bev_msk, bev_plabel, fv_msk, bev_weights_msk, bev_cat, bev_iscrowd, fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, idx = self._load_item(item)
+
+ rec = self.transform(img=img, bev_msk=bev_msk, bev_plabel=bev_plabel, fv_msk=fv_msk, bev_weights_msk=bev_weights_msk, bev_cat=bev_cat,
+ bev_iscrowd=bev_iscrowd, fv_cat=fv_cat, fv_iscrowd=fv_iscrowd, fv_intrinsics=fv_intrinsics,
+ ego_pose=ego_pose)
+ size = (img.size[1], img.size[0])
+
+ # Close the file
+ img.close()
+ bev_msk.close()
+ fv_msk.close()
+
+ rec["index"] = idx
+ rec["size"] = size
+ rec['name'] = idx
+
+ return rec
+
+ def get_image_desc(self, idx):
+ """Look up an image descriptor given the id"""
+ matching = [img_desc for img_desc in self._images if img_desc["id"] == idx]
+ if len(matching) == 1:
+ return matching[0]
+ else:
+ raise ValueError("No image found with id %s" % idx)
\ No newline at end of file
diff --git a/mapper/data/kitti/transform.py b/mapper/data/kitti/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..03652cb2e98749cc00ed6c9bab61bd84b8a9b998
--- /dev/null
+++ b/mapper/data/kitti/transform.py
@@ -0,0 +1,149 @@
+import numpy as np
+import torch
+from torchvision.transforms import functional as tfn
+import torchvision.transforms.functional as tvf
+
+from ..utils import decompose_rotmat
+from ..image import pad_image, rectify_image, resize_image
+from ...utils.wrappers import Camera
+from ..schema import KITTIDataConfiguration
+
+
+class BEVTransform:
+ def __init__(self,
+ cfg: KITTIDataConfiguration, augmentations):
+ self.cfg = cfg
+ self.augmentations = augmentations
+
+ @staticmethod
+ def _compact_labels(msk, cat, iscrowd):
+ ids = np.unique(msk)
+ if 0 not in ids:
+ ids = np.concatenate((np.array([0], dtype=np.int32), ids), axis=0)
+
+ ids_to_compact = np.zeros((ids.max() + 1,), dtype=np.int32)
+ ids_to_compact[ids] = np.arange(0, ids.size, dtype=np.int32)
+
+ msk = ids_to_compact[msk]
+ cat = cat[ids]
+ iscrowd = iscrowd[ids]
+
+ return msk, cat, iscrowd
+
+ def __call__(self, img, bev_msk=None, bev_plabel=None, fv_msk=None, bev_weights_msk=None,
+ bev_cat=None, bev_iscrowd=None, fv_cat=None, fv_iscrowd=None,
+ fv_intrinsics=None, ego_pose=None):
+ # Wrap in np.array
+ if bev_cat is not None:
+ bev_cat = np.array(bev_cat, dtype=np.int32)
+ if bev_iscrowd is not None:
+ bev_iscrowd = np.array(bev_iscrowd, dtype=np.uint8)
+
+ if ego_pose is not None:
+ ego_pose = np.array(ego_pose, dtype=np.float32)
+
+ roll, pitch, yaw = decompose_rotmat(ego_pose[:3, :3])
+
+ # Image transformations
+ img = tfn.to_tensor(img)
+ # img = [self._normalize_image(rgb) for rgb in img]
+ fx = fv_intrinsics[0][0]
+ fy = fv_intrinsics[1][1]
+ cx = fv_intrinsics[0][2]
+ cy = fv_intrinsics[1][2]
+ width = img.shape[2]
+ height = img.shape[1]
+
+ cam = Camera(torch.tensor(
+ [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()
+
+ if not self.cfg.gravity_align:
+ # Turn off gravity alignment
+ roll = 0.0
+ pitch = 0.0
+ img, valid = rectify_image(img, cam, roll, pitch)
+ else:
+ img, valid = rectify_image(
+ img, cam, roll, pitch if self.cfg.rectify_pitch else None
+ )
+ roll = 0.0
+ if self.cfg.rectify_pitch:
+ pitch = 0.0
+
+ if self.cfg.target_focal_length is not None:
+ # Resize to a canonical focal length
+ factor = self.cfg.target_focal_length / cam.f.numpy()
+ size = (np.array(img.shape[-2:][::-1]) * factor).astype(int)
+ img, _, cam, valid = resize_image(img, size, camera=cam, valid=valid)
+ size_out = self.cfg.resize_image
+ if size_out is None:
+ # Round the edges up such that they are multiple of a factor
+ stride = self.cfg.pad_to_multiple
+ size_out = (np.ceil((size / stride)) * stride).astype(int)
+ # Crop or pad such that both edges are of the given size
+ img, valid, cam = pad_image(
+ img, size_out, cam, valid, crop_and_center=False
+ )
+ elif self.cfg.resize_image is not None:
+ img, _, cam, valid = resize_image(
+ img, self.cfg.resize_image, fn=max, camera=cam, valid=valid
+ )
+ if self.cfg.pad_to_square:
+ # Pad such that both edges are of the given size
+ img, valid, cam = pad_image(img, self.cfg.resize_image, cam, valid)
+
+ # Label transformations,
+ if bev_msk is not None:
+ bev_msk = np.expand_dims(
+ np.array(bev_msk, dtype=np.int32, copy=False),
+ axis=0
+ )
+ bev_msk, bev_cat, bev_iscrowd = self._compact_labels(
+ bev_msk, bev_cat, bev_iscrowd
+ )
+
+ bev_msk = torch.from_numpy(bev_msk)
+ bev_cat = torch.from_numpy(bev_cat)
+
+ rotated_mask = torch.rot90(bev_msk, dims=(1, 2))
+ cropped_mask = rotated_mask[:, :672, (rotated_mask.size(2) - 672) // 2:-(rotated_mask.size(2) - 672) // 2]
+
+ bev_msk = cropped_mask.squeeze(0)
+ seg_masks = bev_cat[bev_msk]
+
+ seg_masks_onehot = seg_masks.clone()
+ seg_masks_onehot[seg_masks_onehot == 255] = 0
+ seg_masks_onehot = torch.nn.functional.one_hot(
+ seg_masks_onehot.to(torch.int64),
+ num_classes=self.cfg.num_classes
+ )
+ seg_masks_onehot[seg_masks == 255] = 0
+
+ seg_masks_onehot = seg_masks_onehot.permute(2, 0, 1)
+
+ seg_masks_down = tvf.resize(seg_masks_onehot, (100, 100))
+
+ seg_masks_down = seg_masks_down.permute(1, 2, 0)
+
+ if self.cfg.class_mapping is not None:
+ seg_masks_down = seg_masks_down[:, :, self.cfg.class_mapping]
+
+ img = self.augmentations(img)
+ flood_masks = torch.all(seg_masks_down == 0, dim=2).float()
+
+
+ ret = {
+ "image": img,
+ "valid": valid,
+ "camera": cam,
+ "seg_masks": (seg_masks_down).float().contiguous(),
+ "flood_masks": flood_masks,
+ "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
+ "confidence_map": flood_masks,
+ }
+
+ for key, value in ret.items():
+ if isinstance(value, np.ndarray):
+ ret[key] = torch.from_numpy(value)
+
+ return ret
diff --git a/mapper/data/mapillary/data_module.py b/mapper/data/mapillary/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..7197e60f9f446a80e6b31a8dbc61adfcd92fdcd8
--- /dev/null
+++ b/mapper/data/mapillary/data_module.py
@@ -0,0 +1,317 @@
+import json
+from collections import defaultdict
+import os
+import shutil
+import tarfile
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.utils.data as torchdata
+from omegaconf import DictConfig
+
+from ... import logger
+from .dataset import MapLocDataset
+from ..sequential import chunk_sequence
+from ..torch import collate, worker_init_fn
+from ..schema import MIADataConfiguration
+
+def pack_dump_dict(dump):
+ for per_seq in dump.values():
+ if "points" in per_seq:
+ for chunk in list(per_seq["points"]):
+ points = per_seq["points"].pop(chunk)
+ if points is not None:
+ per_seq["points"][chunk] = np.array(
+ per_seq["points"][chunk], np.float64
+ )
+ for view in per_seq["views"].values():
+ for k in ["R_c2w", "roll_pitch_yaw"]:
+ view[k] = np.array(view[k], np.float32)
+ for k in ["chunk_id"]:
+ if k in view:
+ view.pop(k)
+ if "observations" in view:
+ view["observations"] = np.array(view["observations"])
+ for camera in per_seq["cameras"].values():
+ for k in ["params"]:
+ camera[k] = np.array(camera[k], np.float32)
+ return dump
+
+
+class MapillaryDataModule(pl.LightningDataModule):
+ dump_filename = "dump.json"
+ images_archive = "images.tar.gz"
+ images_dirname = "images/"
+ semantic_masks_dirname = "semantic_masks/"
+ flood_dirname = "flood_fill/"
+
+ def __init__(self, cfg: MIADataConfiguration):
+ super().__init__()
+ self.cfg = cfg
+ self.root = self.cfg.data_dir
+ self.local_dir = None
+
+ def prepare_data(self):
+ for scene in self.cfg.scenes:
+ dump_dir = self.root / scene
+ assert (dump_dir / self.dump_filename).exists(), dump_dir
+ # assert (dump_dir / self.cfg.tiles_filename).exists(), dump_dir
+ if self.local_dir is None:
+ assert (dump_dir / self.images_dirname).exists(), dump_dir
+ continue
+ assert (dump_dir / self.semantic_masks_dirname).exists(), dump_dir
+ assert (dump_dir / self.flood_dirname).exists(), dump_dir
+ # Cache the folder of images locally to speed up reading
+ local_dir = self.local_dir / scene
+ if local_dir.exists():
+ shutil.rmtree(local_dir)
+ local_dir.mkdir(exist_ok=True, parents=True)
+ images_archive = dump_dir / self.images_archive
+ logger.info("Extracting the image archive %s.", images_archive)
+ with tarfile.open(images_archive) as fp:
+ fp.extractall(local_dir)
+
+ def setup(self, stage: Optional[str] = None):
+ self.dumps = {}
+ # self.tile_managers = {}
+ self.image_dirs = {}
+ self.seg_masks_dir = {}
+ self.flood_masks_dir = {}
+ names = []
+
+ for scene in self.cfg.scenes:
+ logger.info("Loading scene %s.", scene)
+ dump_dir = self.root / scene
+
+ logger.info("Loading dump json file %s.", self.dump_filename)
+ with (dump_dir / self.dump_filename).open("r") as fp:
+ self.dumps[scene] = pack_dump_dict(json.load(fp))
+ for seq, per_seq in self.dumps[scene].items():
+ for cam_id, cam_dict in per_seq["cameras"].items():
+ if cam_dict["model"] != "PINHOLE":
+ raise ValueError(
+ f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
+ )
+
+ self.image_dirs[scene] = (
+ (self.local_dir or self.root) / scene / self.images_dirname
+ )
+ assert self.image_dirs[scene].exists(), self.image_dirs[scene]
+
+ self.seg_masks_dir[scene] = (
+ (self.local_dir or self.root) / scene / self.semantic_masks_dirname
+ )
+ assert self.seg_masks_dir[scene].exists(), self.seg_masks_dir[scene]
+
+ self.flood_masks_dir[scene] = (
+ (self.local_dir or self.root) / scene / self.flood_dirname
+ )
+ assert self.flood_masks_dir[scene].exists(), self.flood_masks_dir[scene]
+
+ images = set(x.split('.')[0] for x in os.listdir(self.image_dirs[scene]))
+ flood_masks = set(x.split('.')[0] for x in os.listdir(self.flood_masks_dir[scene]))
+ semantic_masks = set(x.split('.')[0] for x in os.listdir(self.seg_masks_dir[scene]))
+
+ for seq, data in self.dumps[scene].items():
+ for name in data["views"]:
+ if name in images and name.split("_")[0] in flood_masks and name.split("_")[0] in semantic_masks:
+ names.append((scene, seq, name))
+
+ self.parse_splits(self.cfg.split, names)
+ if self.cfg.filter_for is not None:
+ self.filter_elements()
+ self.pack_data()
+
+ def pack_data(self):
+ # We pack the data into compact tensors that can be shared across processes without copy
+ exclude = {
+ "compass_angle",
+ "compass_accuracy",
+ "gps_accuracy",
+ "chunk_key",
+ "panorama_offset",
+ }
+ cameras = {
+ scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
+ for scene, per_scene in self.dumps.items()
+ }
+ points = {
+ scene: {
+ seq: {
+ i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
+ }
+ for seq, per_seq in per_scene.items()
+ }
+ for scene, per_scene in self.dumps.items()
+ }
+ self.data = {}
+
+ # TODO: remove
+ if self.cfg.split == "splits_MGL_13loc.json":
+ # Use Last 20% as Val
+ num_samples_to_move = int(len(self.splits['train']) * 0.2)
+ samples_to_move = self.splits['train'][-num_samples_to_move:]
+ self.splits['val'].extend(samples_to_move)
+ self.splits['train'] = self.splits['train'][:-num_samples_to_move]
+ print(f"Dataset Len: {len(self.splits['train']), len(self.splits['val'])}\n\n\n\n")
+ elif self.cfg.split == "splits_MGL_soma_70k_mappred_random.json":
+ for stage, names in self.splits.items():
+ print("Length of splits {}: ".format(stage), len(self.splits[stage]))
+ for stage, names in self.splits.items():
+ view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
+ data = {k: [] for k in view.keys() - exclude}
+ for scene, seq, name in names:
+ for k in data:
+ data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
+ for k in data:
+ v = np.array(data[k])
+ if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
+ v.dtype, np.floating
+ ):
+ v = torch.from_numpy(v)
+ data[k] = v
+ data["cameras"] = cameras
+ data["points"] = points
+ self.data[stage] = data
+ self.splits[stage] = np.array(names)
+
+ def filter_elements(self):
+ for stage, names in self.splits.items():
+ names_select = []
+ for scene, seq, name in names:
+ view = self.dumps[scene][seq]["views"][name]
+ if self.cfg.filter_for == "ground_plane":
+ if not (1.0 <= view["height"] <= 3.0):
+ continue
+ planes = self.dumps[scene][seq].get("plane")
+ if planes is not None:
+ inliers = planes[str(view["chunk_id"])][-1]
+ if inliers < 10:
+ continue
+ if self.cfg.filter_by_ground_angle is not None:
+ plane = np.array(view["plane_params"])
+ normal = plane[:3] / np.linalg.norm(plane[:3])
+ angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
+ if angle > self.cfg.filter_by_ground_angle:
+ continue
+ elif self.cfg.filter_for == "pointcloud":
+ if len(view["observations"]) < self.cfg.min_num_points:
+ continue
+ elif self.cfg.filter_for is not None:
+ raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
+ names_select.append((scene, seq, name))
+ logger.info(
+ "%s: Keep %d/%d images after filtering for %s.",
+ stage,
+ len(names_select),
+ len(names),
+ self.cfg.filter_for,
+ )
+ self.splits[stage] = names_select
+
+ def parse_splits(self, split_arg, names):
+ if split_arg is None:
+ self.splits = {
+ "train": names,
+ "val": names,
+ }
+ elif isinstance(split_arg, int):
+ names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
+ self.splits = {
+ "train": names[split_arg:],
+ "val": names[:split_arg],
+ }
+ elif isinstance(split_arg, float):
+ names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
+ self.splits = {
+ "train": names[int(split_arg * len(names)) :],
+ "val": names[: int(split_arg * len(names))],
+ }
+ elif isinstance(split_arg, DictConfig):
+ scenes_val = set(split_arg.val)
+ scenes_train = set(split_arg.train)
+ assert len(scenes_val - set(self.cfg.scenes)) == 0
+ assert len(scenes_train - set(self.cfg.scenes)) == 0
+ self.splits = {
+ "train": [n for n in names if n[0] in scenes_train],
+ "val": [n for n in names if n[0] in scenes_val],
+ }
+ elif isinstance(split_arg, str):
+
+ if "/" in split_arg:
+ split_path = self.root / split_arg
+ else:
+ split_path = Path(split_arg)
+
+ with split_path.open("r") as fp:
+ splits = json.load(fp)
+ splits = {
+ k: {loc: set(ids) for loc, ids in split.items()}
+ for k, split in splits.items()
+ }
+ self.splits = {}
+
+ for k, split in splits.items():
+ self.splits[k] = [
+ n
+ for n in names
+ if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
+ ]
+ else:
+ raise ValueError(split_arg)
+
+ def dataset(self, stage: str):
+ return MapLocDataset(
+ stage,
+ self.cfg,
+ self.splits[stage],
+ self.data[stage],
+ self.image_dirs,
+ self.seg_masks_dir,
+ self.flood_masks_dir,
+
+ image_ext=".jpg",
+ )
+
+ def sequence_dataset(self, stage: str, **kwargs):
+ keys = self.splits[stage]
+ seq2indices = defaultdict(list)
+ for index, (_, seq, _) in enumerate(keys):
+ seq2indices[seq].append(index)
+ # chunk the sequences to the required length
+ chunk2indices = {}
+ for seq, indices in seq2indices.items():
+ chunks = chunk_sequence(self.data[stage], indices, **kwargs)
+ for i, sub_indices in enumerate(chunks):
+ chunk2indices[seq, i] = sub_indices
+ # store the index of each chunk in its sequence
+ chunk_indices = torch.full((len(keys),), -1)
+ for (_, chunk_index), idx in chunk2indices.items():
+ chunk_indices[idx] = chunk_index
+ self.data[stage]["chunk_index"] = chunk_indices
+ dataset = self.dataset(stage)
+ return dataset, chunk2indices
+
+ def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
+ dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
+ chunk_keys = sorted(chunk2idx)
+ if shuffle:
+ perm = torch.randperm(len(chunk_keys))
+ chunk_keys = [chunk_keys[i] for i in perm]
+ key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
+ num_workers = self.cfg.loading[stage]["num_workers"]
+ loader = torchdata.DataLoader(
+ dataset,
+ batch_size=None,
+ sampler=key_indices,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ worker_init_fn=worker_init_fn,
+ collate_fn=collate,
+ )
+ return loader, chunk_keys, chunk2idx
diff --git a/mapper/data/mapillary/dataset.py b/mapper/data/mapillary/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad89fe89229223cbad592d0e1071a3912d360c19
--- /dev/null
+++ b/mapper/data/mapillary/dataset.py
@@ -0,0 +1,255 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+import torch.utils.data as torchdata
+import torchvision.transforms as tvf
+from PIL import Image
+from pathlib import Path
+
+from ...models.utils import deg2rad, rotmat2d
+from ...utils.io import read_image
+from ...utils.wrappers import Camera
+from ..image import pad_image, rectify_image, resize_image
+from ..utils import decompose_rotmat
+from ..schema import MIADataConfiguration
+
+
+class MapLocDataset(torchdata.Dataset):
+ def __init__(
+ self,
+ stage: str,
+ cfg: MIADataConfiguration,
+ names: List[str],
+ data: Dict[str, Any],
+ image_dirs: Dict[str, Path],
+ seg_mask_dirs: Dict[str, Path],
+ flood_masks_dirs: Dict[str, Path],
+ image_ext: str = "",
+ ):
+ self.stage = stage
+ self.cfg = deepcopy(cfg)
+ self.data = data
+ self.image_dirs = image_dirs
+ self.seg_mask_dirs = seg_mask_dirs
+ self.flood_masks_dirs = flood_masks_dirs
+ self.names = names
+ self.image_ext = image_ext
+
+ tfs = []
+ self.tfs = tvf.Compose(tfs)
+ self.augmentations = self.get_augmentations()
+
+ def __len__(self):
+ return len(self.names)
+
+ def __getitem__(self, idx):
+ if self.stage == "train" and self.cfg.random:
+ seed = None
+ else:
+ seed = [self.cfg.seed, idx]
+ (seed,) = np.random.SeedSequence(seed).generate_state(1)
+
+ scene, seq, name = self.names[idx]
+
+ view = self.get_view(
+ idx, scene, seq, name, seed
+ )
+
+ return view
+
+ def get_augmentations(self):
+ if self.stage != "train" or not self.cfg.augmentations.enabled:
+ print(f"No Augmentation!", "\n" * 10)
+ self.cfg.augmentations.random_flip = 0.0
+ return tvf.Compose([])
+
+ print(f"Augmentation!", "\n" * 10)
+ augmentations = [
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness,
+ contrast=self.cfg.augmentations.contrast,
+ saturation=self.cfg.augmentations.saturation,
+ hue=self.cfg.augmentations.hue,
+ )
+ ]
+
+ if self.cfg.augmentations.random_resized_crop:
+ augmentations.append(
+ tvf.RandomResizedCrop(scale=(0.8, 1.0))
+ ) # RandomResizedCrop
+
+ if self.cfg.augmentations.gaussian_noise.enabled:
+ augmentations.append(
+ tvf.GaussianNoise(
+ mean=self.cfg.augmentations.gaussian_noise.mean,
+ std=self.cfg.augmentations.gaussian_noise.std,
+ )
+ ) # Gaussian noise
+
+ if self.cfg.augmentations.brightness_contrast.enabled:
+ augmentations.append(
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
+ contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
+ saturation=0, # Keep saturation at 0 for brightness and contrast adjustment
+ hue=0,
+ )
+ ) # Brightness and contrast adjustment
+
+ return tvf.Compose(augmentations)
+
+ def random_flip(self, image, cam, valid, seg_mask, flood_mask, conf_mask):
+ if torch.rand(1) < self.cfg.augmentations.random_flip:
+ image = torch.flip(image, [-1])
+ cam = cam.flip()
+ valid = torch.flip(valid, [-1])
+ seg_mask = torch.flip(seg_mask, [1])
+ flood_mask = torch.flip(flood_mask, [-1])
+ conf_mask = torch.flip(conf_mask, [-1])
+
+ return image, cam, valid, seg_mask, flood_mask, conf_mask
+
+ def get_view(self, idx, scene, seq, name, seed):
+ data = {
+ "index": idx,
+ "name": name,
+ "scene": scene,
+ "sequence": seq,
+ }
+ cam_dict = self.data["cameras"][scene][seq][self.data["camera_id"][idx]]
+ cam = Camera.from_dict(cam_dict).float()
+
+ if "roll_pitch_yaw" in self.data:
+ roll, pitch, yaw = self.data["roll_pitch_yaw"][idx].numpy()
+ else:
+ roll, pitch, yaw = decompose_rotmat(
+ self.data["R_c2w"][idx].numpy())
+
+ image = read_image(self.image_dirs[scene] / (name + self.image_ext))
+ image = Image.fromarray(image)
+ image = self.augmentations(image)
+ image = np.array(image)
+
+ if "plane_params" in self.data:
+ # transform the plane parameters from world to camera frames
+ plane_w = self.data["plane_params"][idx]
+ data["ground_plane"] = torch.cat(
+ [rotmat2d(deg2rad(torch.tensor(yaw)))
+ @ plane_w[:2], plane_w[2:]]
+ )
+
+ image, valid, cam, roll, pitch = self.process_image(
+ image, cam, roll, pitch, seed
+ )
+
+ if "chunk_index" in self.data: # TODO: (cherie) do we need this?
+ data["chunk_id"] = (scene, seq, self.data["chunk_index"][idx])
+
+ # Semantic map extraction
+ seg_mask_path = self.seg_mask_dirs[scene] / \
+ (name.split("_")[0] + ".npy")
+ seg_masks_ours = np.load(seg_mask_path)
+ mask_center = (
+ seg_masks_ours.shape[0] // 2, seg_masks_ours.shape[1] // 2)
+
+ seg_masks_ours = seg_masks_ours[mask_center[0] -
+ 100:mask_center[0], mask_center[1] - 50: mask_center[1] + 50]
+
+ if self.cfg.num_classes == 6:
+ seg_masks_ours = seg_masks_ours[..., [0, 1, 2, 4, 6, 7]]
+
+ flood_mask_path = self.flood_masks_dirs[scene] / \
+ (name.split("_")[0] + ".npy")
+ flood_mask = np.load(flood_mask_path)
+
+ flood_mask = flood_mask[mask_center[0]-100:mask_center[0],
+ mask_center[1] - 50: mask_center[1] + 50]
+
+ confidence_map = flood_mask.copy()
+ confidence_map = (confidence_map - confidence_map.min()) / \
+ (confidence_map.max() - confidence_map.min() + 1e-6)
+
+ seg_masks_ours = torch.from_numpy(seg_masks_ours).float()
+ flood_mask = torch.from_numpy(flood_mask).float()
+ confidence_map = torch.from_numpy(confidence_map).float()
+
+ # Map Augmentations
+ with torch.random.fork_rng(devices=[]):
+ torch.manual_seed(seed)
+ image, cam, valid, seg_masks_ours, flood_mask, confidence_map = self.random_flip(
+ image, cam, valid, seg_masks_ours, flood_mask, confidence_map)
+
+ return {
+ **data,
+ "image": image,
+ "valid": valid,
+ "camera": cam,
+ "seg_masks": seg_masks_ours,
+ "flood_masks": flood_mask,
+ "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
+ "confidence_map": confidence_map
+ # "pixels_per_meter": torch.tensor(canvas.ppm).float(),
+ }
+
+ def process_image(self, image, cam, roll, pitch, seed):
+ image = (
+ torch.from_numpy(np.ascontiguousarray(image))
+ .permute(2, 0, 1)
+ .float()
+ .div_(255)
+ )
+
+ if not self.cfg.gravity_align:
+ # Turn off gravity alignment
+ roll = 0.0
+ pitch = 0.0
+ image, valid = rectify_image(image, cam, roll, pitch)
+ else:
+ image, valid = rectify_image(
+ image, cam, roll, pitch if self.cfg.rectify_pitch else None
+ )
+ roll = 0.0
+ if self.cfg.rectify_pitch:
+ pitch = 0.0
+
+ if self.cfg.target_focal_length is not None:
+ # Resize to a canonical focal length
+ factor = self.cfg.target_focal_length / cam.f.numpy()
+ size = (np.array(image.shape[-2:][::-1]) * factor).astype(int)
+ image, _, cam, valid = resize_image(
+ image, size, camera=cam, valid=valid)
+ size_out = self.cfg.resize_image
+ if size_out is None:
+ # Round the edges up such that they are multiple of a factor
+ stride = self.cfg.pad_to_multiple
+ size_out = (np.ceil((size / stride)) * stride).astype(int)
+ # Crop or pad such that both edges are of the given size
+ image, valid, cam = pad_image(
+ image, size_out, cam, valid, crop_and_center=True
+ )
+ elif self.cfg.resize_image is not None:
+ image, _, cam, valid = resize_image(
+ image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
+ )
+ if self.cfg.pad_to_square:
+ # Pad such that both edges are of the given size
+ image, valid, cam = pad_image(
+ image, self.cfg.resize_image, cam, valid)
+
+ if self.cfg.reduce_fov is not None:
+ h, w = image.shape[-2:]
+ f = float(cam.f[0])
+ fov = np.arctan(w / f / 2)
+ w_new = round(2 * f * np.tan(self.cfg.reduce_fov * fov))
+ image, valid, cam = pad_image(
+ image, (w_new, h), cam, valid, crop_and_center=True
+ )
+
+ with torch.random.fork_rng(devices=[]):
+ torch.manual_seed(seed)
+ image = self.tfs(image)
+
+ return image, valid, cam, roll, pitch
diff --git a/mapper/data/module.py b/mapper/data/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed964cb47b2292cf4d8830eb48b938aef70bac7a
--- /dev/null
+++ b/mapper/data/module.py
@@ -0,0 +1,64 @@
+from typing import Optional
+from omegaconf import DictConfig
+import pytorch_lightning as L
+import torch.utils.data as torchdata
+from .torch import collate, worker_init_fn
+
+
+def get_dataset(name):
+ if name == "mapillary":
+ from .mapillary.data_module import MapillaryDataModule
+ return MapillaryDataModule
+ elif name == "nuscenes":
+ from .nuscenes.data_module import NuScenesData
+ return NuScenesData
+ elif name == "kitti":
+ from .kitti.data_module import BEVKitti360Data
+ return BEVKitti360Data
+ else:
+ raise NotImplementedError(f"Dataset {name} not implemented.")
+
+
+class GenericDataModule(L.LightningDataModule):
+ def __init__(self, cfg: DictConfig):
+ super().__init__()
+ self.cfg = cfg
+ self.data_module = get_dataset(cfg.name)(cfg)
+
+ def prepare_data(self) -> None:
+ self.data_module.prepare_data()
+
+ def setup(self, stage: Optional[str] = None):
+ self.data_module.setup(stage)
+
+ def dataloader(
+ self,
+ stage: str,
+ shuffle: bool = False,
+ num_workers: int = None,
+ sampler: Optional[torchdata.Sampler] = None,
+ ):
+ dataset = self.data_module.dataset(stage)
+ cfg = self.cfg["loading"][stage]
+ num_workers = cfg["num_workers"] if num_workers is None else num_workers
+ loader = torchdata.DataLoader(
+ dataset,
+ batch_size=cfg["batch_size"],
+ num_workers=num_workers,
+ shuffle=shuffle or (stage == "train"),
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ worker_init_fn=worker_init_fn,
+ collate_fn=collate,
+ sampler=sampler,
+ )
+ return loader
+
+ def train_dataloader(self, **kwargs):
+ return self.dataloader("train", **kwargs)
+
+ def val_dataloader(self, **kwargs):
+ return self.dataloader("val", **kwargs)
+
+ def test_dataloader(self, **kwargs):
+ return self.dataloader("test", **kwargs)
\ No newline at end of file
diff --git a/mapper/data/nuscenes/data_module.py b/mapper/data/nuscenes/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..957a482609ef09f217cb2cc8526e193cae0c192a
--- /dev/null
+++ b/mapper/data/nuscenes/data_module.py
@@ -0,0 +1,33 @@
+from ..base import DataBase
+from .dataset import NuScenesDataset
+from ..schema import NuScenesDataConfiguration
+
+class NuScenesData(DataBase):
+ def __init__(self, cfg: NuScenesDataConfiguration):
+ self.cfg = cfg
+ self._dataset = {}
+
+ def prepare_data(self):
+ pass
+
+ def setup(self, stage):
+ if stage is None:
+ stage = 'fit'
+
+ split = {
+ 'fit': 'train',
+ 'val': 'val',
+ 'validate': 'val',
+ 'test': 'test'
+ }[stage]
+
+ self._dataset[split] = NuScenesDataset(
+ split=split,
+ cfg=self.cfg
+ )
+
+ def dataset(self, stage):
+ if self._dataset.get(stage) is None:
+ self.setup(stage)
+
+ return self._dataset[stage]
\ No newline at end of file
diff --git a/mapper/data/nuscenes/dataset.py b/mapper/data/nuscenes/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97e0da5bf05e3507726db4d483fbaadf1a6a73a
--- /dev/null
+++ b/mapper/data/nuscenes/dataset.py
@@ -0,0 +1,207 @@
+import os
+import torch
+import numpy as np
+from pyquaternion import Quaternion
+from nuscenes.nuscenes import NuScenes
+from itertools import chain
+from PIL import Image
+from torchvision import transforms as T
+import torchvision.transforms as tvf
+from torchvision.transforms.functional import to_tensor
+
+from .splits_roddick import create_splits_scenes_roddick
+from ..image import pad_image, rectify_image, resize_image
+from .utils import decode_binary_labels
+from ..utils import decompose_rotmat
+from ...utils.io import read_image
+from ...utils.wrappers import Camera
+from ..schema import NuScenesDataConfiguration
+
+
+class NuScenesDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg: NuScenesDataConfiguration, split="train"):
+
+ self.cfg = cfg
+ self.nusc = NuScenes(version=cfg.version, dataroot=str(cfg.data_dir))
+ self.map_data_root = cfg.map_dir
+ self.split = split
+
+ self.scenes = create_splits_scenes_roddick() # custom based on Roddick et al.
+
+ scene_split = {
+ 'v1.0-trainval': {'train': 'train', 'val': 'val', 'test': 'val'},
+ 'v1.0-mini': {'train': 'mini_train', 'val': 'mini_val'},
+ }[cfg.version][split]
+ self.scenes = self.scenes[scene_split]
+ self.sample = list(filter(lambda sample: self.nusc.get(
+ 'scene', sample['scene_token'])['name'] in self.scenes, self.nusc.sample))
+
+ self.tfs = self.get_augmentations() if split == "train" else T.Compose([])
+
+ data_tokens = []
+ for sample in self.sample:
+ data_token = sample['data']
+ data_token = [v for k,v in data_token.items() if k == "CAM_FRONT"]
+
+ data_tokens.append(data_token)
+
+ data_tokens = list(chain.from_iterable(data_tokens))
+ data = [self.nusc.get('sample_data', token) for token in data_tokens]
+
+ self.data = []
+ for d in data:
+ sample = self.nusc.get('sample', d['sample_token'])
+ scene = self.nusc.get('scene', sample['scene_token'])
+ location = self.nusc.get('log', scene['log_token'])['location']
+
+ file_name = d['filename']
+ ego_pose = self.nusc.get('ego_pose', d['ego_pose_token'])
+ calibrated_sensor = self.nusc.get(
+ "calibrated_sensor", d['calibrated_sensor_token'])
+
+ ego2global = np.eye(4).astype(np.float32)
+ ego2global[:3, :3] = Quaternion(ego_pose['rotation']).rotation_matrix
+ ego2global[:3, 3] = ego_pose['translation']
+
+ sensor2ego = np.eye(4).astype(np.float32)
+ sensor2ego[:3, :3] = Quaternion(
+ calibrated_sensor['rotation']).rotation_matrix
+ sensor2ego[:3, 3] = calibrated_sensor['translation']
+
+ sensor2global = ego2global @ sensor2ego
+
+ rotation = sensor2global[:3, :3]
+ roll, pitch, yaw = decompose_rotmat(rotation)
+
+ fx = calibrated_sensor['camera_intrinsic'][0][0]
+ fy = calibrated_sensor['camera_intrinsic'][1][1]
+ cx = calibrated_sensor['camera_intrinsic'][0][2]
+ cy = calibrated_sensor['camera_intrinsic'][1][2]
+ width = d['width']
+ height = d['height']
+
+ cam = Camera(torch.tensor(
+ [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()
+ self.data.append({
+ 'filename': file_name,
+ 'yaw': yaw,
+ 'pitch': pitch,
+ 'roll': roll,
+ 'cam': cam,
+ 'sensor2global': sensor2global,
+ 'token': d['token'],
+ 'sample_token': d['sample_token'],
+ 'location': location
+ })
+
+ if self.cfg.percentage < 1.0 and split == "train":
+ self.data = self.data[:int(len(self.data) * self.cfg.percentage)]
+
+ def get_augmentations(self):
+
+ print(f"Augmentation!", "\n" * 10)
+ augmentations = [
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness,
+ contrast=self.cfg.augmentations.contrast,
+ saturation=self.cfg.augmentations.saturation,
+ hue=self.cfg.augmentations.hue,
+ )
+ ]
+
+ if self.cfg.augmentations.random_resized_crop:
+ augmentations.append(
+ tvf.RandomResizedCrop(scale=(0.8, 1.0))
+ ) # RandomResizedCrop
+
+ if self.cfg.augmentations.gaussian_noise.enabled:
+ augmentations.append(
+ tvf.GaussianNoise(
+ mean=self.cfg.augmentations.gaussian_noise.mean,
+ std=self.cfg.augmentations.gaussian_noise.std,
+ )
+ ) # Gaussian noise
+
+ if self.cfg.augmentations.brightness_contrast.enabled:
+ augmentations.append(
+ tvf.ColorJitter(
+ brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
+ contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
+ saturation=0, # Keep saturation at 0 for brightness and contrast adjustment
+ hue=0,
+ )
+ ) # Brightness and contrast adjustment
+
+ return tvf.Compose(augmentations)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ d = self.data[idx]
+
+ image = read_image(os.path.join(self.nusc.dataroot, d['filename']))
+ image = np.array(image)
+ cam = d['cam']
+ roll = d['roll']
+ pitch = d['pitch']
+ yaw = d['yaw']
+
+ with Image.open(self.map_data_root / f"{d['token']}.png") as semantic_image:
+ semantic_mask = to_tensor(semantic_image)
+
+ semantic_mask = decode_binary_labels(semantic_mask, self.cfg.num_classes + 1)
+ semantic_mask = torch.nn.functional.max_pool2d(semantic_mask.float(), (2, 2), stride=2) # 2 times downsample
+ semantic_mask = semantic_mask.permute(1, 2, 0)
+ semantic_mask = torch.flip(semantic_mask, [0])
+
+ visibility_mask = semantic_mask[..., -1]
+ semantic_mask = semantic_mask[..., :-1]
+
+ if self.cfg.class_mapping is not None:
+ semantic_mask = semantic_mask[..., self.cfg.class_mapping]
+
+ image = (
+ torch.from_numpy(np.ascontiguousarray(image))
+ .permute(2, 0, 1)
+ .float()
+ .div_(255)
+ )
+
+ if not self.cfg.gravity_align:
+ # Turn off gravity alignment
+ roll = 0.0
+ pitch = 0.0
+ image, valid = rectify_image(image, cam, roll, pitch)
+
+ else:
+ image, valid = rectify_image(
+ image, cam, roll, pitch if self.cfg.rectify_pitch else None
+ )
+ roll = 0.0
+ if self.cfg.rectify_pitch:
+ pitch = 0.0
+ if self.cfg.resize_image is not None:
+ image, _, cam, valid = resize_image(
+ image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
+ )
+ if self.cfg.pad_to_square:
+ image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)
+ image = self.tfs(image)
+
+ confidence_map = visibility_mask.clone().float()
+ confidence_map = (confidence_map - confidence_map.min()) / (confidence_map.max() - confidence_map.min())
+
+ return {
+ "image": image,
+ "roll_pitch_yaw": torch.tensor([roll, pitch, yaw]).float(),
+ "camera": cam,
+ "valid": valid,
+ "seg_masks": semantic_mask.float(),
+ "token": d['token'],
+ "sample_token": d['sample_token'],
+ 'location': d['location'],
+ 'flood_masks': visibility_mask.float(),
+ "confidence_map": confidence_map,
+ 'name': d['sample_token']
+ }
diff --git a/mapper/data/nuscenes/splits_roddick.py b/mapper/data/nuscenes/splits_roddick.py
new file mode 100644
index 0000000000000000000000000000000000000000..81af31ac2129ff65732981c586f766c906032156
--- /dev/null
+++ b/mapper/data/nuscenes/splits_roddick.py
@@ -0,0 +1,197 @@
+def create_splits_scenes_roddick():
+ train_roddick_scenes = [
+ "scene-0002", "scene-0003", "scene-0004", "scene-0005", "scene-0006",
+ "scene-0007", "scene-0008", "scene-0009", "scene-0012", "scene-0013",
+ "scene-0014", "scene-0015", "scene-0016", "scene-0017", "scene-0018",
+ "scene-0019", "scene-0021", "scene-0022", "scene-0023", "scene-0024",
+ "scene-0025", "scene-0026", "scene-0027", "scene-0028", "scene-0029",
+ "scene-0030", "scene-0031", "scene-0032", "scene-0033", "scene-0034",
+ "scene-0035", "scene-0036", "scene-0039", "scene-0042", "scene-0043",
+ "scene-0044", "scene-0045", "scene-0046", "scene-0047", "scene-0048",
+ "scene-0049", "scene-0050", "scene-0051", "scene-0052", "scene-0055",
+ "scene-0056", "scene-0057", "scene-0058", "scene-0059", "scene-0060",
+ "scene-0061", "scene-0062", "scene-0063", "scene-0064", "scene-0065",
+ "scene-0066", "scene-0067", "scene-0068", "scene-0069", "scene-0070",
+ "scene-0071", "scene-0072", "scene-0073", "scene-0074", "scene-0075",
+ "scene-0076", "scene-0092", "scene-0093", "scene-0094", "scene-0095",
+ "scene-0096", "scene-0097", "scene-0098", "scene-0099", "scene-0100",
+ "scene-0101", "scene-0102", "scene-0103", "scene-0104", "scene-0105",
+ "scene-0106", "scene-0107", "scene-0108", "scene-0109", "scene-0110",
+ "scene-0120", "scene-0123", "scene-0124", "scene-0125", "scene-0126",
+ "scene-0127", "scene-0128", "scene-0129", "scene-0130", "scene-0131",
+ "scene-0132", "scene-0133", "scene-0134", "scene-0135", "scene-0138",
+ "scene-0149", "scene-0150", "scene-0151", "scene-0154", "scene-0155",
+ "scene-0157", "scene-0158", "scene-0159", "scene-0161", "scene-0162",
+ "scene-0163", "scene-0164", "scene-0165", "scene-0166", "scene-0167",
+ "scene-0168", "scene-0170", "scene-0171", "scene-0172", "scene-0173",
+ "scene-0174", "scene-0175", "scene-0176", "scene-0177", "scene-0178",
+ "scene-0179", "scene-0180", "scene-0181", "scene-0182", "scene-0183",
+ "scene-0185", "scene-0187", "scene-0188", "scene-0190", "scene-0191",
+ "scene-0192", "scene-0193", "scene-0194", "scene-0195", "scene-0196",
+ "scene-0199", "scene-0200", "scene-0202", "scene-0203", "scene-0204",
+ "scene-0206", "scene-0207", "scene-0208", "scene-0209", "scene-0210",
+ "scene-0211", "scene-0212", "scene-0213", "scene-0214", "scene-0218",
+ "scene-0219", "scene-0220", "scene-0221", "scene-0222", "scene-0224",
+ "scene-0225", "scene-0226", "scene-0227", "scene-0228", "scene-0229",
+ "scene-0230", "scene-0231", "scene-0232", "scene-0233", "scene-0234",
+ "scene-0235", "scene-0236", "scene-0237", "scene-0238", "scene-0239",
+ "scene-0240", "scene-0241", "scene-0242", "scene-0243", "scene-0244",
+ "scene-0245", "scene-0246", "scene-0247", "scene-0248", "scene-0249",
+ "scene-0250", "scene-0251", "scene-0252", "scene-0253", "scene-0254",
+ "scene-0255", "scene-0256", "scene-0257", "scene-0258", "scene-0259",
+ "scene-0260", "scene-0261", "scene-0262", "scene-0263", "scene-0264",
+ "scene-0268", "scene-0270", "scene-0271", "scene-0272", "scene-0273",
+ "scene-0274", "scene-0275", "scene-0276", "scene-0277", "scene-0278",
+ "scene-0283", "scene-0284", "scene-0285", "scene-0286", "scene-0287",
+ "scene-0288", "scene-0289", "scene-0290", "scene-0291", "scene-0292",
+ "scene-0293", "scene-0294", "scene-0295", "scene-0296", "scene-0297",
+ "scene-0298", "scene-0299", "scene-0300", "scene-0301", "scene-0302",
+ "scene-0303", "scene-0304", "scene-0305", "scene-0306", "scene-0315",
+ "scene-0316", "scene-0317", "scene-0318", "scene-0321", "scene-0323",
+ "scene-0324", "scene-0328", "scene-0329", "scene-0330", "scene-0331",
+ "scene-0332", "scene-0344", "scene-0345", "scene-0346", "scene-0349",
+ "scene-0350", "scene-0351", "scene-0352", "scene-0353", "scene-0354",
+ "scene-0355", "scene-0356", "scene-0357", "scene-0358", "scene-0359",
+ "scene-0360", "scene-0361", "scene-0362", "scene-0363", "scene-0364",
+ "scene-0365", "scene-0367", "scene-0370", "scene-0371", "scene-0372",
+ "scene-0373", "scene-0374", "scene-0375", "scene-0376", "scene-0377",
+ "scene-0379", "scene-0380", "scene-0381", "scene-0382", "scene-0383",
+ "scene-0384", "scene-0385", "scene-0386", "scene-0388", "scene-0399",
+ "scene-0400", "scene-0401", "scene-0402", "scene-0403", "scene-0405",
+ "scene-0406", "scene-0407", "scene-0408", "scene-0420", "scene-0421",
+ "scene-0422", "scene-0423", "scene-0424", "scene-0425", "scene-0426",
+ "scene-0427", "scene-0428", "scene-0429", "scene-0430", "scene-0431",
+ "scene-0432", "scene-0433", "scene-0434", "scene-0435", "scene-0436",
+ "scene-0437", "scene-0438", "scene-0439", "scene-0440", "scene-0441",
+ "scene-0442", "scene-0443", "scene-0444", "scene-0445", "scene-0446",
+ "scene-0447", "scene-0448", "scene-0449", "scene-0450", "scene-0451",
+ "scene-0452", "scene-0453", "scene-0454", "scene-0455", "scene-0456",
+ "scene-0457", "scene-0458", "scene-0459", "scene-0461", "scene-0462",
+ "scene-0463", "scene-0464", "scene-0465", "scene-0467", "scene-0468",
+ "scene-0469", "scene-0471", "scene-0472", "scene-0474", "scene-0475",
+ "scene-0476", "scene-0477", "scene-0478", "scene-0479", "scene-0480",
+ "scene-0499", "scene-0500", "scene-0501", "scene-0502", "scene-0504",
+ "scene-0505", "scene-0506", "scene-0507", "scene-0508", "scene-0509",
+ "scene-0510", "scene-0511", "scene-0512", "scene-0513", "scene-0514",
+ "scene-0515", "scene-0517", "scene-0518", "scene-0519", "scene-0520",
+ "scene-0521", "scene-0522", "scene-0523", "scene-0524", "scene-0552",
+ "scene-0553", "scene-0554", "scene-0555", "scene-0559", "scene-0560",
+ "scene-0561", "scene-0562", "scene-0563", "scene-0564", "scene-0565",
+ "scene-0584", "scene-0585", "scene-0586", "scene-0587", "scene-0588",
+ "scene-0589", "scene-0590", "scene-0591", "scene-0592", "scene-0593",
+ "scene-0594", "scene-0595", "scene-0596", "scene-0597", "scene-0598",
+ "scene-0599", "scene-0600", "scene-0625", "scene-0626", "scene-0627",
+ "scene-0629", "scene-0630", "scene-0632", "scene-0633", "scene-0634",
+ "scene-0635", "scene-0636", "scene-0637", "scene-0638", "scene-0639",
+ "scene-0640", "scene-0652", "scene-0653", "scene-0654", "scene-0655",
+ "scene-0656", "scene-0657", "scene-0658", "scene-0659", "scene-0660",
+ "scene-0661", "scene-0662", "scene-0663", "scene-0664", "scene-0665",
+ "scene-0666", "scene-0667", "scene-0668", "scene-0669", "scene-0670",
+ "scene-0671", "scene-0672", "scene-0673", "scene-0674", "scene-0675",
+ "scene-0676", "scene-0677", "scene-0678", "scene-0679", "scene-0681",
+ "scene-0683", "scene-0684", "scene-0685", "scene-0686", "scene-0687",
+ "scene-0688", "scene-0689", "scene-0695", "scene-0696", "scene-0697",
+ "scene-0698", "scene-0700", "scene-0701", "scene-0703", "scene-0704",
+ "scene-0705", "scene-0706", "scene-0707", "scene-0708", "scene-0709",
+ "scene-0710", "scene-0711", "scene-0712", "scene-0713", "scene-0714",
+ "scene-0715", "scene-0716", "scene-0717", "scene-0718", "scene-0719",
+ "scene-0726", "scene-0727", "scene-0728", "scene-0730", "scene-0731",
+ "scene-0733", "scene-0734", "scene-0735", "scene-0736", "scene-0737",
+ "scene-0738", "scene-0780", "scene-0781", "scene-0782", "scene-0783",
+ "scene-0784", "scene-0786", "scene-0787", "scene-0789", "scene-0790",
+ "scene-0791", "scene-0792", "scene-0802", "scene-0806", "scene-0808",
+ "scene-0809", "scene-0810", "scene-0811", "scene-0812", "scene-0813",
+ "scene-0815", "scene-0816", "scene-0817", "scene-0819", "scene-0820",
+ "scene-0821", "scene-0822", "scene-0847", "scene-0848", "scene-0849",
+ "scene-0850", "scene-0851", "scene-0852", "scene-0853", "scene-0854",
+ "scene-0855", "scene-0856", "scene-0858", "scene-0860", "scene-0861",
+ "scene-0862", "scene-0863", "scene-0864", "scene-0865", "scene-0866",
+ "scene-0868", "scene-0869", "scene-0870", "scene-0871", "scene-0872",
+ "scene-0873", "scene-0875", "scene-0876", "scene-0877", "scene-0878",
+ "scene-0880", "scene-0882", "scene-0883", "scene-0884", "scene-0885",
+ "scene-0886", "scene-0887", "scene-0888", "scene-0889", "scene-0890",
+ "scene-0891", "scene-0892", "scene-0893", "scene-0894", "scene-0895",
+ "scene-0896", "scene-0897", "scene-0898", "scene-0899", "scene-0900",
+ "scene-0901", "scene-0902", "scene-0903", "scene-0904", "scene-0905",
+ "scene-0906", "scene-0907", "scene-0908", "scene-0909", "scene-0916",
+ "scene-0917", "scene-0921", "scene-0922", "scene-0923", "scene-0925",
+ "scene-0926", "scene-0927", "scene-0928", "scene-0929", "scene-0930",
+ "scene-0931", "scene-0945", "scene-0947", "scene-0949", "scene-0952",
+ "scene-0953", "scene-0955", "scene-0956", "scene-0957", "scene-0958",
+ "scene-0959", "scene-0960", "scene-0961", "scene-0966", "scene-0967",
+ "scene-0968", "scene-0969", "scene-0971", "scene-0972", "scene-0975",
+ "scene-0976", "scene-0977", "scene-0978", "scene-0979", "scene-0980",
+ "scene-0981", "scene-0982", "scene-0983", "scene-0984", "scene-0988",
+ "scene-0989", "scene-0990", "scene-0991", "scene-0992", "scene-0994",
+ "scene-0995", "scene-0996", "scene-0997", "scene-0998", "scene-0999",
+ "scene-1000", "scene-1001", "scene-1004", "scene-1005", "scene-1006",
+ "scene-1007", "scene-1008", "scene-1009", "scene-1010", "scene-1011",
+ "scene-1012", "scene-1013", "scene-1014", "scene-1015", "scene-1019",
+ "scene-1020", "scene-1021", "scene-1022", "scene-1023", "scene-1024",
+ "scene-1025", "scene-1044", "scene-1045", "scene-1046", "scene-1047",
+ "scene-1048", "scene-1049", "scene-1050", "scene-1051", "scene-1052",
+ "scene-1053", "scene-1054", "scene-1064", "scene-1065", "scene-1066",
+ "scene-1067", "scene-1068", "scene-1069", "scene-1070", "scene-1071",
+ "scene-1072", "scene-1073", "scene-1074", "scene-1075", "scene-1076",
+ "scene-1077", "scene-1078", "scene-1079", "scene-1080", "scene-1081",
+ "scene-1082", "scene-1083", "scene-1084", "scene-1085", "scene-1086",
+ "scene-1087", "scene-1088", "scene-1089", "scene-1090", "scene-1091",
+ "scene-1092", "scene-1093", "scene-1094", "scene-1095", "scene-1096",
+ "scene-1097", "scene-1098", "scene-1099", "scene-1100", "scene-1101",
+ "scene-1102", "scene-1104", "scene-1105", "scene-1106", "scene-1107",
+ "scene-1108", "scene-1109", "scene-1110"]
+
+ val_roddick_scenes = [
+ "scene-0001", "scene-0010", "scene-0011", "scene-0020", "scene-0038",
+ "scene-0041", "scene-0053", "scene-0054", "scene-0121", "scene-0122",
+ "scene-0139", "scene-0152", "scene-0160", "scene-0184", "scene-0269",
+ "scene-0347", "scene-0348", "scene-0366", "scene-0368", "scene-0369",
+ "scene-0378", "scene-0389", "scene-0390", "scene-0391", "scene-0392",
+ "scene-0393", "scene-0394", "scene-0395", "scene-0396", "scene-0397",
+ "scene-0398", "scene-0411", "scene-0412", "scene-0413", "scene-0414",
+ "scene-0415", "scene-0416", "scene-0417", "scene-0418", "scene-0419",
+ "scene-0525", "scene-0526", "scene-0527", "scene-0528", "scene-0529",
+ "scene-0530", "scene-0531", "scene-0532", "scene-0533", "scene-0534",
+ "scene-0535", "scene-0536", "scene-0537", "scene-0538", "scene-0539",
+ "scene-0541", "scene-0542", "scene-0543", "scene-0544", "scene-0545",
+ "scene-0546", "scene-0556", "scene-0557", "scene-0558", "scene-0566",
+ "scene-0568", "scene-0570", "scene-0571", "scene-0572", "scene-0573",
+ "scene-0574", "scene-0575", "scene-0576", "scene-0577", "scene-0578",
+ "scene-0580", "scene-0582", "scene-0583", "scene-0642", "scene-0643",
+ "scene-0644", "scene-0645", "scene-0646", "scene-0647", "scene-0648",
+ "scene-0649", "scene-0650", "scene-0651", "scene-0739", "scene-0740",
+ "scene-0741", "scene-0744", "scene-0746", "scene-0747", "scene-0749",
+ "scene-0750", "scene-0751", "scene-0752", "scene-0757", "scene-0758",
+ "scene-0759", "scene-0760", "scene-0761", "scene-0762", "scene-0763",
+ "scene-0764", "scene-0765", "scene-0767", "scene-0768", "scene-0769",
+ "scene-0770", "scene-0771", "scene-0775", "scene-0777", "scene-0778",
+ "scene-0794", "scene-0795", "scene-0796", "scene-0797", "scene-0798",
+ "scene-0799", "scene-0800", "scene-0803", "scene-0804", "scene-0911",
+ "scene-0912", "scene-0913", "scene-0914", "scene-0915", "scene-0919",
+ "scene-0920", "scene-0924", "scene-0962", "scene-0963", "scene-1002",
+ "scene-1003", "scene-1016", "scene-1017", "scene-1018", "scene-1055",
+ "scene-1056", "scene-1057", "scene-1058", "scene-1059", "scene-1060",
+ "scene-1061", "scene-1062", "scene-1063"]
+
+
+ calibration_roddick_scenes = [
+ "scene-0852", "scene-0429", "scene-0956", "scene-0194", "scene-0811",
+ "scene-1110", "scene-1107", "scene-0294", "scene-0900", "scene-0596",
+ "scene-0296", "scene-0885", "scene-0866", "scene-0105", "scene-0782",
+ "scene-0191", "scene-0876", "scene-0133", "scene-0231", "scene-0847",
+ "scene-0363", "scene-0026", "scene-0791", "scene-0909", "scene-0002",
+ "scene-0283", "scene-0007", "scene-0251", "scene-1100", "scene-0668",
+ "scene-0584", "scene-0287", "scene-0260", "scene-0171", "scene-0789",
+ "scene-0108", "scene-0190", "scene-0206", "scene-0635", "scene-0815",
+ "scene-0058", "scene-0710", "scene-0302", "scene-0639", "scene-0166",
+ "scene-0094", "scene-0735", "scene-0321", "scene-1091", "scene-0344"
+ ]
+
+
+ scenes_dict = {
+ "train": train_roddick_scenes,
+ "val": val_roddick_scenes,
+ "calibration": calibration_roddick_scenes
+ }
+
+ return scenes_dict
\ No newline at end of file
diff --git a/mapper/data/nuscenes/utils.py b/mapper/data/nuscenes/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..865622b5a7faf94e717457eb0263b6f17d34ac5f
--- /dev/null
+++ b/mapper/data/nuscenes/utils.py
@@ -0,0 +1,214 @@
+import os
+import numpy as np
+from shapely import geometry, affinity
+from pyquaternion import Quaternion
+import cv2
+
+from nuscenes.eval.detection.utils import category_to_detection_name
+from nuscenes.eval.detection.constants import DETECTION_NAMES
+from nuscenes.utils.data_classes import LidarPointCloud
+
+from nuscenes.map_expansion.map_api import NuScenesMap
+from shapely.strtree import STRtree
+from collections import OrderedDict
+import torch
+
+def decode_binary_labels(labels, nclass):
+ bits = torch.pow(2, torch.arange(nclass))
+ return (labels & bits.view(-1, 1, 1)) > 0
+
+def transform_polygon(polygon, affine):
+ """
+ Transform a 2D polygon
+ """
+ a, b, tx, c, d, ty = affine.flatten()[:6]
+ return affinity.affine_transform(polygon, [a, b, c, d, tx, ty])
+
+
+def render_polygon(mask, polygon, extents, resolution, value=1):
+ if len(polygon) == 0:
+ return
+ polygon = (polygon - np.array(extents[:2])) / resolution
+ polygon = np.ascontiguousarray(polygon).round().astype(np.int32)
+ cv2.fillConvexPoly(mask, polygon, value)
+
+def transform(matrix, vectors):
+ vectors = np.dot(matrix[:-1, :-1], vectors.T)
+ vectors = vectors.T + matrix[:-1, -1]
+ return vectors
+
+CAMERA_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT',
+ 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK']
+
+NUSCENES_CLASS_NAMES = [
+ 'drivable_area', 'ped_crossing', 'walkway', 'carpark', 'car', 'truck',
+ 'bus', 'trailer', 'construction_vehicle', 'pedestrian', 'motorcycle',
+ 'bicycle', 'traffic_cone', 'barrier'
+]
+
+STATIC_CLASSES = ['drivable_area', 'ped_crossing', 'walkway', 'carpark_area']
+
+LOCATIONS = ['boston-seaport', 'singapore-onenorth', 'singapore-queenstown',
+ 'singapore-hollandvillage']
+
+def load_map_data(dataroot, location):
+
+ # Load the NuScenes map object
+ nusc_map = NuScenesMap(dataroot, location)
+
+ map_data = OrderedDict()
+ for layer in STATIC_CLASSES:
+
+ # Retrieve all data associated with the current layer
+ records = getattr(nusc_map, layer)
+ polygons = list()
+
+ # Drivable area records can contain multiple polygons
+ if layer == 'drivable_area':
+ for record in records:
+
+ # Convert each entry in the record into a shapely object
+ for token in record['polygon_tokens']:
+ poly = nusc_map.extract_polygon(token)
+ if poly.is_valid:
+ polygons.append(poly)
+ else:
+ for record in records:
+
+ # Convert each entry in the record into a shapely object
+ poly = nusc_map.extract_polygon(record['polygon_token'])
+ if poly.is_valid:
+ polygons.append(poly)
+
+
+ # Store as an R-Tree for fast intersection queries
+ map_data[layer] = STRtree(polygons)
+
+ return map_data
+
+def iterate_samples(nuscenes, start_token):
+ sample_token = start_token
+ while sample_token != '':
+ sample = nuscenes.get('sample', sample_token)
+ yield sample
+ sample_token = sample['next']
+
+
+def get_map_masks(nuscenes, map_data, sample_data, extents, resolution):
+
+ # Render each layer sequentially
+ layers = [get_layer_mask(nuscenes, polys, sample_data, extents,
+ resolution) for layer, polys in map_data.items()]
+
+ return np.stack(layers, axis=0)
+
+
+def get_layer_mask(nuscenes, polygons, sample_data, extents, resolution):
+
+ # Get the 2D affine transform from bev coords to map coords
+ tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]]
+ inv_tfm = np.linalg.inv(tfm)
+
+ # Create a patch representing the birds-eye-view region in map coordinates
+ map_patch = geometry.box(*extents)
+ map_patch = transform_polygon(map_patch, tfm)
+
+ # Initialise the map mask
+ x1, z1, x2, z2 = extents
+ mask = np.zeros((int((z2 - z1) / resolution), int((x2 - x1) / resolution)),
+ dtype=np.uint8)
+
+ # Find all polygons which intersect with the area of interest
+ for polygon in polygons.query(map_patch):
+
+ polygon = polygon.intersection(map_patch)
+
+ # Transform into map coordinates
+ polygon = transform_polygon(polygon, inv_tfm)
+
+ # Render the polygon to the mask
+ render_shapely_polygon(mask, polygon, extents, resolution)
+
+ return mask
+
+
+
+
+def get_object_masks(nuscenes, sample_data, extents, resolution):
+
+ # Initialize object masks
+ nclass = len(DETECTION_NAMES) + 1
+ grid_width = int((extents[2] - extents[0]) / resolution)
+ grid_height = int((extents[3] - extents[1]) / resolution)
+ masks = np.zeros((nclass, grid_height, grid_width), dtype=np.uint8)
+
+ # Get the 2D affine transform from bev coords to map coords
+ tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]]
+ inv_tfm = np.linalg.inv(tfm)
+
+ for box in nuscenes.get_boxes(sample_data['token']):
+
+ # Get the index of the class
+ det_name = category_to_detection_name(box.name)
+ if det_name not in DETECTION_NAMES:
+ class_id = -1
+ else:
+ class_id = DETECTION_NAMES.index(det_name)
+
+ # Get bounding box coordinates in the grid coordinate frame
+ bbox = box.bottom_corners()[:2]
+ local_bbox = np.dot(inv_tfm[:2, :2], bbox).T + inv_tfm[:2, 2]
+
+ # Render the rotated bounding box to the mask
+ render_polygon(masks[class_id], local_bbox, extents, resolution)
+
+ return masks.astype(np.bool)
+
+
+def get_sensor_transform(nuscenes, sample_data):
+
+ # Load sensor transform data
+ sensor = nuscenes.get(
+ 'calibrated_sensor', sample_data['calibrated_sensor_token'])
+ sensor_tfm = make_transform_matrix(sensor)
+
+ # Load ego pose data
+ pose = nuscenes.get('ego_pose', sample_data['ego_pose_token'])
+ pose_tfm = make_transform_matrix(pose)
+
+ return np.dot(pose_tfm, sensor_tfm)
+
+
+def load_point_cloud(nuscenes, sample_data):
+
+ # Load point cloud
+ lidar_path = os.path.join(nuscenes.dataroot, sample_data['filename'])
+ pcl = LidarPointCloud.from_file(lidar_path)
+ return pcl.points[:3, :].T
+
+
+def make_transform_matrix(record):
+ """
+ Create a 4x4 transform matrix from a calibrated_sensor or ego_pose record
+ """
+ transform = np.eye(4)
+ transform[:3, :3] = Quaternion(record['rotation']).rotation_matrix
+ transform[:3, 3] = np.array(record['translation'])
+ return transform
+
+
+def render_shapely_polygon(mask, polygon, extents, resolution):
+
+ if polygon.geom_type == 'Polygon':
+
+ # Render exteriors
+ render_polygon(mask, polygon.exterior.coords, extents, resolution, 1)
+
+ # Render interiors
+ for hole in polygon.interiors:
+ render_polygon(mask, hole.coords, extents, resolution, 0)
+
+ # Handle the case of compound shapes
+ else:
+ for poly in polygon:
+ render_shapely_polygon(mask, poly, extents, resolution)
\ No newline at end of file
diff --git a/mapper/data/schema.py b/mapper/data/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbca169a9a97fbee4e80eeed210879d739daa6f6
--- /dev/null
+++ b/mapper/data/schema.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass
+from typing import Optional, Any, Dict
+from pathlib import Path
+
+@dataclass
+class AugmentationConfiguration:
+ gaussian_noise: dict
+ brightness_contrast: dict
+
+ enabled: bool = False
+ brightness: float = 0.5
+ contrast: float = 0.5
+ saturation: float = 0.5
+ hue: float = 0.5
+ random_resized_crop: Any = False
+ random_flip: float = 0.5
+
+
+@dataclass(kw_only=True)
+class DataConfiguration:
+ augmentations: AugmentationConfiguration
+
+ loading: Dict[str, Dict[str, Any]]
+
+ target_focal_length: Optional[int] = None
+ reduce_fov: Optional[bool] = None
+ resize_image: Optional[Any] = None
+ pad_to_square: Optional[bool] = None
+ pad_to_multiple: Optional[int] = None
+ gravity_align: Optional[bool] = None
+ rectify_pitch: Optional[bool] = True
+ num_classes: int
+
+ name: str
+ seed: Optional[int] = 0
+ random: Optional[bool] = True
+ num_threads: Optional[int] = None
+
+@dataclass(kw_only=True)
+class MIADataConfiguration(DataConfiguration):
+
+ scenes: list[str]
+ split: Any
+ data_dir: Path
+ pixel_per_meter: int
+ crop_size_meters: int
+
+ name: str = "mapillary"
+ filter_for: Optional[str] = None
+ filter_by_ground_angle: Optional[float] = None
+ min_num_points: int = 0
+
+@dataclass(kw_only=True)
+class KITTIDataConfiguration(DataConfiguration):
+ seam_root_dir: Path
+ dataset_root_dir: Path
+ bev_percentage: float
+
+ pixel_per_meter: int
+ crop_size_meters: int
+
+ class_mapping: Optional[Any] = None
+ percentage: float = 1.0
+
+@dataclass(kw_only=True)
+class NuScenesDataConfiguration(DataConfiguration):
+ data_dir: Path
+ map_dir: Path
+ pixel_per_meter: int
+ crop_size_meters: int
+
+ percentage: float = 1.0
+ class_mapping: Optional[Any] = None
+ version: str = "v1.0-trainval"
+
\ No newline at end of file
diff --git a/mapper/data/sequential.py b/mapper/data/sequential.py
new file mode 100644
index 0000000000000000000000000000000000000000..15eb39fb4c1051413c872b8a8be4c9f24109d19f
--- /dev/null
+++ b/mapper/data/sequential.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+import torch
+
+
+def chunk_sequence(
+ data,
+ indices,
+ *,
+ names=None,
+ max_length=100,
+ min_length=1,
+ max_delay_s=None,
+ max_inter_dist=None,
+ max_total_dist=None,
+):
+ sort_array = data.get("capture_time", data.get("index"))
+ if sort_array is None:
+ sort_array = indices if names is None else names
+ indices = sorted(indices, key=lambda i: sort_array[i].tolist())
+ centers = torch.stack([data["t_c2w"][i][:2] for i in indices]).numpy()
+ dists = np.linalg.norm(np.diff(centers, axis=0), axis=-1)
+ if "capture_time" in data:
+ times = torch.stack([data["capture_time"][i] for i in indices])
+ times = times.double() / 1e3 # ms to s
+ delays = np.diff(times, axis=0)
+ else:
+ delays = np.zeros_like(dists)
+ chunks = [[indices[0]]]
+ dist_total = 0
+ for dist, delay, idx in zip(dists, delays, indices[1:]):
+ dist_total += dist
+ if (
+ (max_inter_dist is not None and dist > max_inter_dist)
+ or (max_total_dist is not None and dist_total > max_total_dist)
+ or (max_delay_s is not None and delay > max_delay_s)
+ or len(chunks[-1]) >= max_length
+ ):
+ chunks.append([])
+ dist_total = 0
+ chunks[-1].append(idx)
+ chunks = list(filter(lambda c: len(c) >= min_length, chunks))
+ chunks = sorted(chunks, key=len, reverse=True)
+ return chunks
\ No newline at end of file
diff --git a/mapper/data/torch.py b/mapper/data/torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9493b5795fa0322b38d1232efa1b8050a9c5ee
--- /dev/null
+++ b/mapper/data/torch.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import collections
+import os
+
+import torch
+from torch.utils.data import get_worker_info
+from torch.utils.data._utils.collate import (
+ default_collate_err_msg_format,
+ np_str_obj_array_pattern,
+)
+from lightning_fabric.utilities.seed import pl_worker_init_function
+
+def collate(batch):
+ """Difference with PyTorch default_collate: it can stack other tensor-like objects.
+ Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+ https://github.com/cvg/pixloc
+ Released under the Apache License 2.0
+ """
+ if not isinstance(batch, list): # no batching
+ return batch
+
+ # Filter None Elements
+ batch = [elem for elem in batch if elem is not None]
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum(x.numel() for x in batch)
+ storage = elem.storage()._new_shared(numel, device=elem.device)
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
+ return torch.stack(batch, 0, out=out)
+ elif (
+ elem_type.__module__ == "numpy"
+ and elem_type.__name__ != "str_"
+ and elem_type.__name__ != "string_"
+ ):
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, (str, bytes)):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError("each element in list of batch should be of equal size")
+ transposed = zip(*batch)
+ return [collate(samples) for samples in transposed]
+ else:
+ # try to stack anyway in case the object implements stacking.
+ try:
+ return torch.stack(batch, 0)
+ except TypeError as e:
+ if "expected Tensor as element" in str(e):
+ return batch
+ else:
+ raise e
+
+
+def set_num_threads(nt):
+ """Force numpy and other libraries to use a limited number of threads."""
+ try:
+ import mkl
+ except ImportError:
+ pass
+ else:
+ mkl.set_num_threads(nt)
+ torch.set_num_threads(1)
+ os.environ["IPC_ENABLE"] = "1"
+ for o in [
+ "OPENBLAS_NUM_THREADS",
+ "NUMEXPR_NUM_THREADS",
+ "OMP_NUM_THREADS",
+ "MKL_NUM_THREADS",
+ ]:
+ os.environ[o] = str(nt)
+
+
+def worker_init_fn(i):
+ info = get_worker_info()
+ pl_worker_init_function(info.id)
+ num_threads = info.dataset.cfg.get("num_threads")
+ if num_threads is not None:
+ set_num_threads(num_threads)
\ No newline at end of file
diff --git a/mapper/data/utils.py b/mapper/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..92610cd667b99ed7db0dae804a17c728baca93e5
--- /dev/null
+++ b/mapper/data/utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+
+def crop_map(raster, xy, size, seed=None):
+ h, w = raster.shape[-2:]
+ state = np.random.RandomState(seed)
+ top = state.randint(0, h - size + 1)
+ left = state.randint(0, w - size + 1)
+ raster = raster[..., top : top + size, left : left + size]
+ xy -= np.array([left, top])
+ return raster, xy
+
+
+def decompose_rotmat(R_c2w):
+ R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
+ rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
+ roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
+ return roll, pitch, yaw
\ No newline at end of file
diff --git a/mapper/mapper.py b/mapper/mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..43e464fa5c32ffd7bad67c45183f089e20c9d8db
--- /dev/null
+++ b/mapper/mapper.py
@@ -0,0 +1,112 @@
+import time
+import torch
+import hydra
+import pytorch_lightning as pl
+from typing import Any
+
+from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from pathlib import Path
+from dataclasses import dataclass
+
+from .module import GenericModule
+from .data.module import GenericDataModule
+from .callbacks import EvalSaveCallback, ImageLoggerCallback
+from .models.schema import ModelConfiguration, DINOConfiguration, ResNetConfiguration
+from .data.schema import MIADataConfiguration, KITTIDataConfiguration, NuScenesDataConfiguration
+
+
+@dataclass
+class ExperimentConfiguration:
+ name: str
+
+@dataclass
+class Configuration:
+ model: ModelConfiguration
+ experiment: ExperimentConfiguration
+ data: Any
+ training: Any
+
+
+cs = ConfigStore.instance()
+
+# Store root configuration schema
+cs.store(name="pretrain", node=Configuration)
+cs.store(name="mapper_nuscenes", node=Configuration)
+cs.store(name="mapper_kitti", node=Configuration)
+
+# Store data configuration schema
+cs.store(group="schema/data", name="mia",
+ node=MIADataConfiguration, package="data")
+cs.store(group="schema/data", name="kitti", node=KITTIDataConfiguration, package="data")
+cs.store(group="schema/data", name="nuscenes", node=NuScenesDataConfiguration, package="data")
+
+cs.store(group="model/schema/backbone", name="dino", node=DINOConfiguration, package="model.image_encoder.backbone")
+cs.store(group="model/schema/backbone", name="resnet", node=ResNetConfiguration, package="model.image_encoder.backbone")
+
+
+@hydra.main(version_base=None, config_path="conf", config_name="pretrain")
+def train(cfg: Configuration):
+ OmegaConf.resolve(cfg)
+
+ dm = GenericDataModule(cfg.data)
+
+ model = GenericModule(cfg)
+
+ exp_name_with_time = cfg.experiment.name + \
+ "_" + time.strftime("%Y-%m-%d_%H-%M-%S")
+
+ callbacks: list[pl.Callback]
+
+ if cfg.training.eval:
+ save_dir = Path(cfg.training.save_dir)
+ save_dir.mkdir(parents=True, exist_ok=True)
+
+ callbacks = [
+ EvalSaveCallback(save_dir=save_dir)
+ ]
+
+ logger = None
+ else:
+ callbacks = [
+ ImageLoggerCallback(num_classes=cfg.training.num_classes),
+ ModelCheckpoint(
+ monitor=cfg.training.checkpointing.monitor,
+ save_last=cfg.training.checkpointing.save_last,
+ save_top_k=cfg.training.checkpointing.save_top_k,
+ )
+ ]
+
+ logger = WandbLogger(
+ name=exp_name_with_time,
+ id=exp_name_with_time,
+ entity="mappred-large",
+ project="map-pred-full-v3",
+ )
+
+ logger.watch(model, log="all", log_freq=500)
+
+ if cfg.training.checkpoint is not None:
+ state_dict = torch.load(cfg.training.checkpoint)['state_dict']
+ model.load_state_dict(state_dict, strict=False)
+
+ trainer_args = OmegaConf.to_container(cfg.training.trainer)
+ trainer_args['callbacks'] = callbacks
+ trainer_args['logger'] = logger
+
+ trainer = pl.Trainer(**trainer_args)
+
+ if cfg.training.eval:
+ trainer.test(model, datamodule=dm)
+ else:
+ trainer.fit(model, datamodule=dm)
+
+
+if __name__ == "__main__":
+ pl.seed_everything(42)
+ torch.set_float32_matmul_precision("high")
+
+ train()
diff --git a/mapper/models/__init__.py b/mapper/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47bc93f87050756a36fde7814d45cfb7ae4b644
--- /dev/null
+++ b/mapper/models/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+import inspect
+
+from .base import BaseModel
+
+
+def get_class(mod_name, base_path, BaseClass):
+ """Get the class object which inherits from BaseClass and is defined in
+ the module named mod_name, child of base_path.
+ """
+ mod_path = "{}.{}".format(base_path, mod_name)
+ mod = __import__(mod_path, fromlist=[""])
+ classes = inspect.getmembers(mod, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == mod_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+
+
+def get_model(name):
+ return get_class(name, __name__, BaseModel)
diff --git a/mapper/models/base.py b/mapper/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..93d6458c551d3b1ff97ad07fbeaaee84f79802ba
--- /dev/null
+++ b/mapper/models/base.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""
+Base class for trainable models.
+"""
+
+from abc import ABCMeta, abstractmethod
+from copy import copy
+
+from omegaconf import OmegaConf
+from torch import nn
+
+
+class BaseModel(nn.Module, metaclass=ABCMeta):
+
+ required_data_keys = []
+ strict_conf = True
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ self.conf = conf
+ OmegaConf.set_readonly(conf, True)
+ OmegaConf.set_struct(conf, True)
+ self.required_data_keys = copy(self.required_data_keys)
+ self._init(conf)
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+
+ def recursive_key_check(expected, given):
+ for key in expected:
+ assert key in given, f"Missing key {key} in data"
+ if isinstance(expected, dict):
+ recursive_key_check(expected[key], given[key])
+
+ recursive_key_check(self.required_data_keys, data)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def loss(self, pred, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def metrics(self):
+ return {} # no metrics
diff --git a/mapper/models/bev_projection.py b/mapper/models/bev_projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc3d941cc2b919bc5d5b7f15b54a204bfdf3120
--- /dev/null
+++ b/mapper/models/bev_projection.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import torch
+from torch.nn.functional import grid_sample
+
+from ..utils.geometry import from_homogeneous
+from .utils import make_grid
+
+
+class PolarProjectionDepth(torch.nn.Module):
+ def __init__(self, z_max, ppm, scale_range, z_min=None):
+ super().__init__()
+ self.z_max = z_max
+ self.Δ = Δ = 1 / ppm
+ self.z_min = z_min = Δ if z_min is None else z_min
+ self.scale_range = scale_range
+ z_steps = torch.arange(z_min, z_max + Δ, Δ)
+ self.register_buffer("depth_steps", z_steps, persistent=False)
+
+ def sample_depth_scores(self, pixel_scales, camera):
+ scale_steps = camera.f[..., None, 1] / self.depth_steps.flip(-1)
+ log_scale_steps = torch.log2(scale_steps)
+ scale_min, scale_max = self.scale_range
+ log_scale_norm = (log_scale_steps - scale_min) / \
+ (scale_max - scale_min)
+ log_scale_norm = log_scale_norm * 2 - 1 # in [-1, 1]
+
+ values = pixel_scales.flatten(1, 2).unsqueeze(-1)
+ indices = log_scale_norm.unsqueeze(-1)
+ indices = torch.stack([torch.zeros_like(indices), indices], -1)
+ depth_scores = grid_sample(values, indices, align_corners=True)
+ depth_scores = depth_scores.reshape(
+ pixel_scales.shape[:-1] + (len(self.depth_steps),)
+ )
+ return depth_scores
+
+ def forward(
+ self,
+ image,
+ pixel_scales,
+ camera,
+ return_total_score=False,
+ ):
+ depth_scores = self.sample_depth_scores(pixel_scales, camera)
+ depth_prob = torch.softmax(depth_scores, dim=1)
+ image_polar = torch.einsum("...dhw,...hwz->...dzw", image, depth_prob)
+ if return_total_score:
+ cell_score = torch.logsumexp(depth_scores, dim=1, keepdim=True)
+ return image_polar, cell_score.squeeze(1)
+ return image_polar
+
+
+class CartesianProjection(torch.nn.Module):
+ def __init__(self, z_max, x_max, ppm, z_min=None):
+ super().__init__()
+ self.z_max = z_max
+ self.x_max = x_max
+ self.Δ = Δ = 1 / ppm
+ self.z_min = z_min = Δ if z_min is None else z_min
+
+ grid_xz = make_grid(
+ x_max * 2 + Δ, z_max, step_y=Δ, step_x=Δ, orig_y=Δ, orig_x=-x_max, y_up=True
+ )
+ self.register_buffer("grid_xz", grid_xz, persistent=False)
+
+ def grid_to_polar(self, cam):
+ f, c = cam.f[..., 0][..., None, None], cam.c[..., 0][..., None, None]
+ u = from_homogeneous(self.grid_xz).squeeze(-1) * f + c
+ z_idx = (self.grid_xz[..., 1] - self.z_min) / \
+ self.Δ # convert z value to index
+ z_idx = z_idx[None].expand_as(u)
+ grid_polar = torch.stack([u, z_idx], -1)
+ return grid_polar
+
+ def sample_from_polar(self, image_polar, valid_polar, grid_uz):
+ size = grid_uz.new_tensor(image_polar.shape[-2:][::-1])
+ grid_uz_norm = (grid_uz + 0.5) / size * 2 - 1
+ grid_uz_norm = grid_uz_norm * \
+ grid_uz.new_tensor([1, -1]) # y axis is up
+ image_bev = grid_sample(image_polar, grid_uz_norm, align_corners=False)
+
+ if valid_polar is None:
+ valid = torch.ones_like(image_polar[..., :1, :, :])
+ else:
+ valid = valid_polar.to(image_polar)[:, None]
+ valid = grid_sample(valid, grid_uz_norm, align_corners=False)
+ valid = valid.squeeze(1) > (1 - 1e-4)
+
+ return image_bev, valid
+
+ def forward(self, image_polar, valid_polar, cam):
+ grid_uz = self.grid_to_polar(cam)
+ image, valid = self.sample_from_polar(
+ image_polar, valid_polar, grid_uz)
+ return image, valid, grid_uz
diff --git a/mapper/models/dinov2/__init__.py b/mapper/models/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/mapper/models/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/mapper/models/dinov2/configs/__init__.py b/mapper/models/dinov2/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68e0830c62ea19649b6cd2361995f6df309d7640
--- /dev/null
+++ b/mapper/models/dinov2/configs/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import pathlib
+
+from omegaconf import OmegaConf
+
+
+def load_config(config_name: str):
+ config_filename = config_name + ".yaml"
+ return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
+
+
+dinov2_default_config = load_config("ssl_default_config")
+
+
+def load_and_merge_config(config_name: str):
+ default_config = OmegaConf.create(dinov2_default_config)
+ loaded_config = load_config(config_name)
+ return OmegaConf.merge(default_config, loaded_config)
diff --git a/mapper/models/dinov2/configs/eval/vitb14_pretrain.yaml b/mapper/models/dinov2/configs/eval/vitb14_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..117d0f027ca26cd8ce6c010bb78d5a8fac42c70e
--- /dev/null
+++ b/mapper/models/dinov2/configs/eval/vitb14_pretrain.yaml
@@ -0,0 +1,6 @@
+student:
+ arch: vit_base
+ patch_size: 14
+crops:
+ global_crops_size: 518 # this is to set up the position embeddings properly
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/eval/vitg14_pretrain.yaml b/mapper/models/dinov2/configs/eval/vitg14_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a96dd5b117b4d59ee210b65037821f1b3e3f16e3
--- /dev/null
+++ b/mapper/models/dinov2/configs/eval/vitg14_pretrain.yaml
@@ -0,0 +1,7 @@
+student:
+ arch: vit_giant2
+ patch_size: 14
+ ffn_layer: swiglufused
+crops:
+ global_crops_size: 518 # this is to set up the position embeddings properly
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/eval/vitl14_pretrain.yaml b/mapper/models/dinov2/configs/eval/vitl14_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7a984548bd034f762d455419d7193917fa462dd8
--- /dev/null
+++ b/mapper/models/dinov2/configs/eval/vitl14_pretrain.yaml
@@ -0,0 +1,6 @@
+student:
+ arch: vit_large
+ patch_size: 14
+crops:
+ global_crops_size: 518 # this is to set up the position embeddings properly
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/eval/vits14_pretrain.yaml b/mapper/models/dinov2/configs/eval/vits14_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..afbdb4ba14f1c97130a25b579360f4d817cda495
--- /dev/null
+++ b/mapper/models/dinov2/configs/eval/vits14_pretrain.yaml
@@ -0,0 +1,6 @@
+student:
+ arch: vit_small
+ patch_size: 14
+crops:
+ global_crops_size: 518 # this is to set up the position embeddings properly
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/mapper/models/dinov2/configs/eval/vits14_reg4_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d25fd638389bfba9220792302dc9dbf5d9a2406a
--- /dev/null
+++ b/mapper/models/dinov2/configs/eval/vits14_reg4_pretrain.yaml
@@ -0,0 +1,9 @@
+student:
+ arch: vit_small
+ patch_size: 14
+ num_register_tokens: 4
+ interpolate_antialias: true
+ interpolate_offset: 0.0
+crops:
+ global_crops_size: 518 # this is to set up the position embeddings properly
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/ssl_default_config.yaml b/mapper/models/dinov2/configs/ssl_default_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64c9fdfce20a1542353f899417a48d75f7bdd569
--- /dev/null
+++ b/mapper/models/dinov2/configs/ssl_default_config.yaml
@@ -0,0 +1,118 @@
+MODEL:
+ WEIGHTS: ''
+compute_precision:
+ grad_scaler: true
+ teacher:
+ backbone:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp16
+ buffer_dtype: fp32
+ dino_head:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp16
+ buffer_dtype: fp32
+ ibot_head:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp16
+ buffer_dtype: fp32
+ student:
+ backbone:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp16
+ buffer_dtype: fp32
+ dino_head:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp32
+ buffer_dtype: fp32
+ ibot_head:
+ sharding_strategy: SHARD_GRAD_OP
+ mixed_precision:
+ param_dtype: fp16
+ reduce_dtype: fp32
+ buffer_dtype: fp32
+dino:
+ loss_weight: 1.0
+ head_n_prototypes: 65536
+ head_bottleneck_dim: 256
+ head_nlayers: 3
+ head_hidden_dim: 2048
+ koleo_loss_weight: 0.1
+ibot:
+ loss_weight: 1.0
+ mask_sample_probability: 0.5
+ mask_ratio_min_max:
+ - 0.1
+ - 0.5
+ separate_head: false
+ head_n_prototypes: 65536
+ head_bottleneck_dim: 256
+ head_nlayers: 3
+ head_hidden_dim: 2048
+train:
+ batch_size_per_gpu: 64
+ dataset_path: ImageNet:split=TRAIN
+ output_dir: .
+ saveckp_freq: 20
+ seed: 0
+ num_workers: 10
+ OFFICIAL_EPOCH_LENGTH: 1250
+ cache_dataset: true
+ centering: "centering" # or "sinkhorn_knopp"
+student:
+ arch: vit_large
+ patch_size: 16
+ drop_path_rate: 0.3
+ layerscale: 1.0e-05
+ drop_path_uniform: true
+ pretrained_weights: ''
+ ffn_layer: "mlp"
+ block_chunks: 0
+ qkv_bias: true
+ proj_bias: true
+ ffn_bias: true
+ num_register_tokens: 0
+ interpolate_antialias: false
+ interpolate_offset: 0.1
+teacher:
+ momentum_teacher: 0.992
+ final_momentum_teacher: 1
+ warmup_teacher_temp: 0.04
+ teacher_temp: 0.07
+ warmup_teacher_temp_epochs: 30
+optim:
+ epochs: 100
+ weight_decay: 0.04
+ weight_decay_end: 0.4
+ base_lr: 0.004 # learning rate for a batch size of 1024
+ lr: 0. # will be set after applying scaling rule
+ warmup_epochs: 10
+ min_lr: 1.0e-06
+ clip_grad: 3.0
+ freeze_last_layer_epochs: 1
+ scaling_rule: sqrt_wrt_1024
+ patch_embed_lr_mult: 0.2
+ layerwise_decay: 0.9
+ adamw_beta1: 0.9
+ adamw_beta2: 0.999
+crops:
+ global_crops_scale:
+ - 0.32
+ - 1.0
+ local_crops_number: 8
+ local_crops_scale:
+ - 0.05
+ - 0.32
+ global_crops_size: 224
+ local_crops_size: 96
+evaluation:
+ eval_period_iterations: 12500
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/train/vitg14.yaml b/mapper/models/dinov2/configs/train/vitg14.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d05cf0d59e07ac6e4a2b0f9bdcb6131d7c508962
--- /dev/null
+++ b/mapper/models/dinov2/configs/train/vitg14.yaml
@@ -0,0 +1,26 @@
+dino:
+ head_n_prototypes: 131072
+ head_bottleneck_dim: 384
+ibot:
+ separate_head: true
+ head_n_prototypes: 131072
+train:
+ batch_size_per_gpu: 12
+ dataset_path: ImageNet22k
+ centering: sinkhorn_knopp
+student:
+ arch: vit_giant2
+ patch_size: 14
+ drop_path_rate: 0.4
+ ffn_layer: swiglufused
+ block_chunks: 4
+teacher:
+ momentum_teacher: 0.994
+optim:
+ epochs: 500
+ weight_decay_end: 0.2
+ base_lr: 2.0e-04 # learning rate for a batch size of 1024
+ warmup_epochs: 80
+ layerwise_decay: 1.0
+crops:
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/train/vitl14.yaml b/mapper/models/dinov2/configs/train/vitl14.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d9b491dcc6a522c71328fc2933dd0501123c8f6b
--- /dev/null
+++ b/mapper/models/dinov2/configs/train/vitl14.yaml
@@ -0,0 +1,26 @@
+dino:
+ head_n_prototypes: 131072
+ head_bottleneck_dim: 384
+ibot:
+ separate_head: true
+ head_n_prototypes: 131072
+train:
+ batch_size_per_gpu: 32
+ dataset_path: ImageNet22k
+ centering: sinkhorn_knopp
+student:
+ arch: vit_large
+ patch_size: 14
+ drop_path_rate: 0.4
+ ffn_layer: swiglufused
+ block_chunks: 4
+teacher:
+ momentum_teacher: 0.994
+optim:
+ epochs: 500
+ weight_decay_end: 0.2
+ base_lr: 2.0e-04 # learning rate for a batch size of 1024
+ warmup_epochs: 80
+ layerwise_decay: 1.0
+crops:
+ local_crops_size: 98
\ No newline at end of file
diff --git a/mapper/models/dinov2/configs/train/vitl16_short.yaml b/mapper/models/dinov2/configs/train/vitl16_short.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3e7e72864c92175a1354142ac1d64da8070d1e5e
--- /dev/null
+++ b/mapper/models/dinov2/configs/train/vitl16_short.yaml
@@ -0,0 +1,6 @@
+# this corresponds to the default config
+train:
+ dataset_path: ImageNet:split=TRAIN
+ batch_size_per_gpu: 64
+student:
+ block_chunks: 4
diff --git a/mapper/models/dinov2/data/__init__.py b/mapper/models/dinov2/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ded47ea63a7b184ff74a040e2c2c514cda273ef
--- /dev/null
+++ b/mapper/models/dinov2/data/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .adapters import DatasetWithEnumeratedTargets
+from .loaders import make_data_loader, make_dataset, SamplerType
+from .collate import collate_data_and_cast
+from .masking import MaskingGenerator
+from .augmentations import DataAugmentationDINO
diff --git a/mapper/models/dinov2/data/adapters.py b/mapper/models/dinov2/data/adapters.py
new file mode 100644
index 0000000000000000000000000000000000000000..2097bad046fb1052267d5f2bb99c798045f00c92
--- /dev/null
+++ b/mapper/models/dinov2/data/adapters.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from typing import Any, Tuple
+
+from torch.utils.data import Dataset
+
+
+class DatasetWithEnumeratedTargets(Dataset):
+ def __init__(self, dataset):
+ self._dataset = dataset
+
+ def get_image_data(self, index: int) -> bytes:
+ return self._dataset.get_image_data(index)
+
+ def get_target(self, index: int) -> Tuple[Any, int]:
+ target = self._dataset.get_target(index)
+ return (index, target)
+
+ def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
+ image, target = self._dataset[index]
+ target = index if target is None else target
+ return image, (index, target)
+
+ def __len__(self) -> int:
+ return len(self._dataset)
diff --git a/mapper/models/dinov2/data/augmentations.py b/mapper/models/dinov2/data/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b1eaa942c14f75b88d9e14732e141e8909b0a1
--- /dev/null
+++ b/mapper/models/dinov2/data/augmentations.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from torchvision import transforms
+
+from .transforms import (
+ GaussianBlur,
+ make_normalize_transform,
+)
+
+
+logger = logging.getLogger("dinov2")
+
+
+class DataAugmentationDINO(object):
+ def __init__(
+ self,
+ global_crops_scale,
+ local_crops_scale,
+ local_crops_number,
+ global_crops_size=224,
+ local_crops_size=96,
+ ):
+ self.global_crops_scale = global_crops_scale
+ self.local_crops_scale = local_crops_scale
+ self.local_crops_number = local_crops_number
+ self.global_crops_size = global_crops_size
+ self.local_crops_size = local_crops_size
+
+ logger.info("###################################")
+ logger.info("Using data augmentation parameters:")
+ logger.info(f"global_crops_scale: {global_crops_scale}")
+ logger.info(f"local_crops_scale: {local_crops_scale}")
+ logger.info(f"local_crops_number: {local_crops_number}")
+ logger.info(f"global_crops_size: {global_crops_size}")
+ logger.info(f"local_crops_size: {local_crops_size}")
+ logger.info("###################################")
+
+ # random resized crop and flip
+ self.geometric_augmentation_global = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(p=0.5),
+ ]
+ )
+
+ self.geometric_augmentation_local = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.RandomHorizontalFlip(p=0.5),
+ ]
+ )
+
+ # color distorsions / blurring
+ color_jittering = transforms.Compose(
+ [
+ transforms.RandomApply(
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+ p=0.8,
+ ),
+ transforms.RandomGrayscale(p=0.2),
+ ]
+ )
+
+ global_transfo1_extra = GaussianBlur(p=1.0)
+
+ global_transfo2_extra = transforms.Compose(
+ [
+ GaussianBlur(p=0.1),
+ transforms.RandomSolarize(threshold=128, p=0.2),
+ ]
+ )
+
+ local_transfo_extra = GaussianBlur(p=0.5)
+
+ # normalization
+ self.normalize = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ make_normalize_transform(),
+ ]
+ )
+
+ self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
+ self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
+ self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
+
+ def __call__(self, image):
+ output = {}
+
+ # global crops:
+ im1_base = self.geometric_augmentation_global(image)
+ global_crop_1 = self.global_transfo1(im1_base)
+
+ im2_base = self.geometric_augmentation_global(image)
+ global_crop_2 = self.global_transfo2(im2_base)
+
+ output["global_crops"] = [global_crop_1, global_crop_2]
+
+ # global crops for teacher:
+ output["global_crops_teacher"] = [global_crop_1, global_crop_2]
+
+ # local crops:
+ local_crops = [
+ self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
+ ]
+ output["local_crops"] = local_crops
+ output["offsets"] = ()
+
+ return output
diff --git a/mapper/models/dinov2/data/collate.py b/mapper/models/dinov2/data/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3e32f357a76e6f32162cee14cb6ae1665a4827a
--- /dev/null
+++ b/mapper/models/dinov2/data/collate.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+
+def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
+ # dtype = torch.half # TODO: Remove
+
+ n_global_crops = len(samples_list[0][0]["global_crops"])
+ n_local_crops = len(samples_list[0][0]["local_crops"])
+
+ collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
+
+ collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
+
+ B = len(collated_global_crops)
+ N = n_tokens
+ n_samples_masked = int(B * mask_probability)
+ probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
+ upperbound = 0
+ masks_list = []
+ for i in range(0, n_samples_masked):
+ prob_min = probs[i]
+ prob_max = probs[i + 1]
+ masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
+ upperbound += int(N * prob_max)
+ for i in range(n_samples_masked, B):
+ masks_list.append(torch.BoolTensor(mask_generator(0)))
+
+ random.shuffle(masks_list)
+
+ collated_masks = torch.stack(masks_list).flatten(1)
+ mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+ masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+
+ return {
+ "collated_global_crops": collated_global_crops.to(dtype),
+ "collated_local_crops": collated_local_crops.to(dtype),
+ "collated_masks": collated_masks,
+ "mask_indices_list": mask_indices_list,
+ "masks_weight": masks_weight,
+ "upperbound": upperbound,
+ "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+ }
diff --git a/mapper/models/dinov2/data/loaders.py b/mapper/models/dinov2/data/loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a2f0210efa0fa96be764665b5d6792191b1e72
--- /dev/null
+++ b/mapper/models/dinov2/data/loaders.py
@@ -0,0 +1,222 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+from enum import Enum
+from typing import Any, Callable, List, Optional, TypeVar
+
+import torch
+from torch.utils.data import Sampler
+
+from .datasets import ImageNet, ImageNet22k
+from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
+
+
+logger = logging.getLogger("dinov2")
+
+
+class SamplerType(Enum):
+ DISTRIBUTED = 0
+ EPOCH = 1
+ INFINITE = 2
+ SHARDED_INFINITE = 3
+ SHARDED_INFINITE_NEW = 4
+
+
+def _make_bool_str(b: bool) -> str:
+ return "yes" if b else "no"
+
+
+def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
+ def transform(sample):
+ image, target = sample
+ if image_transform is not None:
+ image = image_transform(image)
+ if target_transform is not None:
+ target = target_transform(target)
+ return image, target
+
+ return transform
+
+
+def _parse_dataset_str(dataset_str: str):
+ tokens = dataset_str.split(":")
+
+ name = tokens[0]
+ kwargs = {}
+
+ for token in tokens[1:]:
+ key, value = token.split("=")
+ assert key in ("root", "extra", "split")
+ kwargs[key] = value
+
+ if name == "ImageNet":
+ class_ = ImageNet
+ if "split" in kwargs:
+ kwargs["split"] = ImageNet.Split[kwargs["split"]]
+ elif name == "ImageNet22k":
+ class_ = ImageNet22k
+ else:
+ raise ValueError(f'Unsupported dataset "{name}"')
+
+ return class_, kwargs
+
+
+def make_dataset(
+ *,
+ dataset_str: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+):
+ """
+ Creates a dataset with the specified parameters.
+
+ Args:
+ dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
+ transform: A transform to apply to images.
+ target_transform: A transform to apply to targets.
+
+ Returns:
+ The created dataset.
+ """
+ logger.info(f'using dataset: "{dataset_str}"')
+
+ class_, kwargs = _parse_dataset_str(dataset_str)
+ dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
+
+ logger.info(f"# of dataset samples: {len(dataset):,d}")
+
+ # Aggregated datasets do not expose (yet) these attributes, so add them.
+ if not hasattr(dataset, "transform"):
+ setattr(dataset, "transform", transform)
+ if not hasattr(dataset, "target_transform"):
+ setattr(dataset, "target_transform", target_transform)
+
+ return dataset
+
+
+def _make_sampler(
+ *,
+ dataset,
+ type: Optional[SamplerType] = None,
+ shuffle: bool = False,
+ seed: int = 0,
+ size: int = -1,
+ advance: int = 0,
+) -> Optional[Sampler]:
+ sample_count = len(dataset)
+
+ if type == SamplerType.INFINITE:
+ logger.info("sampler: infinite")
+ if size > 0:
+ raise ValueError("sampler size > 0 is invalid")
+ return InfiniteSampler(
+ sample_count=sample_count,
+ shuffle=shuffle,
+ seed=seed,
+ advance=advance,
+ )
+ elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
+ logger.info("sampler: sharded infinite")
+ if size > 0:
+ raise ValueError("sampler size > 0 is invalid")
+ # TODO: Remove support for old shuffling
+ use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
+ return ShardedInfiniteSampler(
+ sample_count=sample_count,
+ shuffle=shuffle,
+ seed=seed,
+ advance=advance,
+ use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
+ )
+ elif type == SamplerType.EPOCH:
+ logger.info("sampler: epoch")
+ if advance > 0:
+ raise NotImplementedError("sampler advance > 0 is not supported")
+ size = size if size > 0 else sample_count
+ logger.info(f"# of samples / epoch: {size:,d}")
+ return EpochSampler(
+ size=size,
+ sample_count=sample_count,
+ shuffle=shuffle,
+ seed=seed,
+ )
+ elif type == SamplerType.DISTRIBUTED:
+ logger.info("sampler: distributed")
+ if size > 0:
+ raise ValueError("sampler size > 0 is invalid")
+ if advance > 0:
+ raise ValueError("sampler advance > 0 is invalid")
+ return torch.utils.data.DistributedSampler(
+ dataset=dataset,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=False,
+ )
+
+ logger.info("sampler: none")
+ return None
+
+
+T = TypeVar("T")
+
+
+def make_data_loader(
+ *,
+ dataset,
+ batch_size: int,
+ num_workers: int,
+ shuffle: bool = True,
+ seed: int = 0,
+ sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
+ sampler_size: int = -1,
+ sampler_advance: int = 0,
+ drop_last: bool = True,
+ persistent_workers: bool = False,
+ collate_fn: Optional[Callable[[List[T]], Any]] = None,
+):
+ """
+ Creates a data loader with the specified parameters.
+
+ Args:
+ dataset: A dataset (third party, LaViDa or WebDataset).
+ batch_size: The size of batches to generate.
+ num_workers: The number of workers to use.
+ shuffle: Whether to shuffle samples.
+ seed: The random seed to use.
+ sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
+ sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
+ sampler_advance: How many samples to skip (when applicable).
+ drop_last: Whether the last non-full batch of data should be dropped.
+ persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
+ collate_fn: Function that performs batch collation
+ """
+
+ sampler = _make_sampler(
+ dataset=dataset,
+ type=sampler_type,
+ shuffle=shuffle,
+ seed=seed,
+ size=sampler_size,
+ advance=sampler_advance,
+ )
+
+ logger.info("using PyTorch data loader")
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=drop_last,
+ persistent_workers=persistent_workers,
+ collate_fn=collate_fn,
+ )
+
+ try:
+ logger.info(f"# of batches: {len(data_loader):,d}")
+ except TypeError: # data loader has no length
+ logger.info("infinite data loader")
+ return data_loader
diff --git a/mapper/models/dinov2/data/masking.py b/mapper/models/dinov2/data/masking.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab12aa7bf138b916b16a9a2ed1a628a2759dbec6
--- /dev/null
+++ b/mapper/models/dinov2/data/masking.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import random
+import math
+import numpy as np
+
+
+class MaskingGenerator:
+ def __init__(
+ self,
+ input_size,
+ num_masking_patches=None,
+ min_num_patches=4,
+ max_num_patches=None,
+ min_aspect=0.3,
+ max_aspect=None,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.height, self.width = input_size
+
+ self.num_patches = self.height * self.width
+ self.num_masking_patches = num_masking_patches
+
+ self.min_num_patches = min_num_patches
+ self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
+
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+ def __repr__(self):
+ repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
+ self.height,
+ self.width,
+ self.min_num_patches,
+ self.max_num_patches,
+ self.num_masking_patches,
+ self.log_aspect_ratio[0],
+ self.log_aspect_ratio[1],
+ )
+ return repr_str
+
+ def get_shape(self):
+ return self.height, self.width
+
+ def _mask(self, mask, max_mask_patches):
+ delta = 0
+ for _ in range(10):
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < self.width and h < self.height:
+ top = random.randint(0, self.height - h)
+ left = random.randint(0, self.width - w)
+
+ num_masked = mask[top : top + h, left : left + w].sum()
+ # Overlap
+ if 0 < h * w - num_masked <= max_mask_patches:
+ for i in range(top, top + h):
+ for j in range(left, left + w):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ def __call__(self, num_masking_patches=0):
+ mask = np.zeros(shape=self.get_shape(), dtype=bool)
+ mask_count = 0
+ while mask_count < num_masking_patches:
+ max_mask_patches = num_masking_patches - mask_count
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
+
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return mask
diff --git a/mapper/models/dinov2/data/samplers.py b/mapper/models/dinov2/data/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6562197d94652bb9a75a5fc722fcb2c65ca161be
--- /dev/null
+++ b/mapper/models/dinov2/data/samplers.py
@@ -0,0 +1,229 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+from typing import Any, Optional
+import warnings
+
+import numpy as np
+import torch
+from torch.utils.data.sampler import Sampler
+
+import dinov2.distributed as distributed
+
+
+class EpochSampler(Sampler):
+ def __init__(
+ self,
+ *,
+ size: int,
+ sample_count: int,
+ shuffle: bool = False,
+ seed: int = 0,
+ start: Optional[int] = None,
+ step: Optional[int] = None,
+ ):
+ self._size = size
+ self._sample_count = sample_count
+ self._shuffle = shuffle
+ self._seed = seed
+ self._start = distributed.get_global_rank() if start is None else start
+ self._step = distributed.get_global_size() if step is None else step
+ self._epoch = 0
+
+ def __iter__(self):
+ count = (self._size + self._sample_count - 1) // self._sample_count
+ tiled_indices = np.tile(np.arange(self._sample_count), count)
+ if self._shuffle:
+ seed = self._seed * self._epoch if self._seed != 0 else self._epoch
+ rng = np.random.default_rng(seed)
+ iterable = rng.choice(tiled_indices, self._size, replace=False)
+ else:
+ iterable = tiled_indices[: self._size]
+
+ yield from itertools.islice(iterable, self._start, None, self._step)
+
+ def __len__(self):
+ return (self._size - self._start + self._step - 1) // self._step
+
+ def set_epoch(self, epoch):
+ self._epoch = epoch
+
+
+def _get_numpy_dtype(size: int) -> Any:
+ return np.int32 if size <= 2**31 else np.int64
+
+
+def _get_torch_dtype(size: int) -> Any:
+ return torch.int32 if size <= 2**31 else torch.int64
+
+
+def _generate_randperm_indices(*, size: int, generator: torch.Generator):
+ """Generate the indices of a random permutation."""
+ dtype = _get_torch_dtype(size)
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
+ perm = torch.arange(size, dtype=dtype)
+ for i in range(size):
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
+
+ # Always swap even if no-op
+ value = perm[j].item()
+ perm[j] = perm[i].item()
+ perm[i] = value
+ yield value
+
+
+class InfiniteSampler(Sampler):
+ def __init__(
+ self,
+ *,
+ sample_count: int,
+ shuffle: bool = False,
+ seed: int = 0,
+ start: Optional[int] = None,
+ step: Optional[int] = None,
+ advance: int = 0,
+ ):
+ self._sample_count = sample_count
+ self._seed = seed
+ self._shuffle = shuffle
+ self._start = distributed.get_global_rank() if start is None else start
+ self._step = distributed.get_global_size() if step is None else step
+ self._advance = advance
+
+ def __iter__(self):
+ if self._shuffle:
+ iterator = self._shuffled_iterator()
+ else:
+ iterator = self._iterator()
+
+ yield from itertools.islice(iterator, self._advance, None)
+
+ def _iterator(self):
+ assert not self._shuffle
+
+ while True:
+ iterable = range(self._sample_count)
+ yield from itertools.islice(iterable, self._start, None, self._step)
+
+ def _shuffled_iterator(self):
+ assert self._shuffle
+
+ # Instantiate a generator here (rather than in the ctor) to keep the class
+ # picklable (requirement of mp.spawn)
+ generator = torch.Generator().manual_seed(self._seed)
+
+ while True:
+ iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
+ yield from itertools.islice(iterable, self._start, None, self._step)
+
+
+# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
+# but avoids a full in-place random permutation generation.
+def _shuffle_tensor_slice(
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
+) -> np.ndarray:
+ stop = len(tensor)
+ count = stop // step
+ drop_count = stop - step * count
+ if drop_count:
+ warnings.warn(f"# of dropped samples: {drop_count}")
+
+ dtype = _get_numpy_dtype(stop)
+ result = np.empty(count, dtype=dtype)
+
+ for i in range(count):
+ j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
+
+ result[i] = result[j]
+ result[j] = tensor[start + i * step].item()
+
+ return result
+
+
+def _new_shuffle_tensor_slice(
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
+) -> np.ndarray:
+ stop = len(tensor)
+ count = stop // step
+ dtype = torch.int64 # Needed for using randperm result as indices
+ count = stop // step
+ drop_count = stop - step * count
+ if drop_count:
+ warnings.warn(f"# of dropped samples: {drop_count}")
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
+ return tensor[start::step][indices].numpy()
+
+
+def _make_seed(seed: int, start: int, iter_count: int) -> int:
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
+ return seed + start + (iter_count << 24)
+
+
+class ShardedInfiniteSampler(Sampler):
+ def __init__(
+ self,
+ *,
+ sample_count: int,
+ shuffle: bool = False,
+ seed: int = 0,
+ start: Optional[int] = None,
+ step: Optional[int] = None,
+ advance: int = 0,
+ use_new_shuffle_tensor_slice: bool = False,
+ ):
+ self._sample_count = sample_count
+ self._seed = seed
+ self._shuffle = shuffle
+ self._start = distributed.get_global_rank() if start is None else start
+ self._step = distributed.get_global_size() if step is None else step
+ self._advance = advance
+ self._iter_count = 0
+ self._shuffle_tensor_slice_fn = (
+ _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
+ )
+
+ def __iter__(self):
+ iter_count = self._advance // self._sample_count
+ if iter_count > 0:
+ self._advance -= iter_count * self._sample_count
+ self._iter_count += iter_count
+
+ if self._shuffle:
+ iterator = self._shuffled_iterator()
+ else:
+ iterator = self._iterator()
+
+ yield from itertools.islice(iterator, self._advance, None)
+
+ def _iterator(self):
+ assert not self._shuffle
+
+ while True:
+ iterable = range(self._sample_count)
+ yield from itertools.islice(iterable, self._start, None, self._step)
+
+ def _shuffled_iterator(self):
+ assert self._shuffle
+
+ # Instantiate a generator here (rather than in the ctor) to be keep the class
+ # picklable (requirement of mp.spawn)
+ generator = torch.Generator()
+
+ # Always shuffle everything first
+ generator.manual_seed(self._seed)
+ dtype = _get_torch_dtype(self._sample_count)
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
+
+ while True:
+ # Re-seed on each iteration to allow skipping whole permutations
+ seed = _make_seed(self._seed, self._start, self._iter_count)
+ generator.manual_seed(seed)
+
+ iterable = self._shuffle_tensor_slice_fn(
+ tensor=perm, start=self._start, step=self._step, generator=generator
+ )
+ yield from iterable
+ self._iter_count += 1
diff --git a/mapper/models/dinov2/data/transforms.py b/mapper/models/dinov2/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb5f252b50c54d58f160528c9f2b00fad47103c7
--- /dev/null
+++ b/mapper/models/dinov2/data/transforms.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from typing import Sequence
+
+import torch
+from torchvision import transforms
+
+
+class GaussianBlur(transforms.RandomApply):
+ """
+ Apply Gaussian Blur to the PIL image.
+ """
+
+ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
+ # NOTE: torchvision is applying 1 - probability to return the original image
+ keep_p = 1 - p
+ transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
+ super().__init__(transforms=[transform], p=keep_p)
+
+
+class MaybeToTensor(transforms.ToTensor):
+ """
+ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
+ """
+
+ def __call__(self, pic):
+ """
+ Args:
+ pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
+ Returns:
+ Tensor: Converted image.
+ """
+ if isinstance(pic, torch.Tensor):
+ return pic
+ return super().__call__(pic)
+
+
+# Use timm's names
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def make_normalize_transform(
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
+) -> transforms.Normalize:
+ return transforms.Normalize(mean=mean, std=std)
+
+
+# This roughly matches torchvision's preset for classification training:
+# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
+def make_classification_train_transform(
+ *,
+ crop_size: int = 224,
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ hflip_prob: float = 0.5,
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
+):
+ transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+ if hflip_prob > 0.0:
+ transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
+ transforms_list.extend(
+ [
+ MaybeToTensor(),
+ make_normalize_transform(mean=mean, std=std),
+ ]
+ )
+ return transforms.Compose(transforms_list)
+
+
+# This matches (roughly) torchvision's preset for classification evaluation:
+# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
+def make_classification_eval_transform(
+ *,
+ resize_size: int = 256,
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ crop_size: int = 224,
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
+) -> transforms.Compose:
+ transforms_list = [
+ transforms.Resize(resize_size, interpolation=interpolation),
+ transforms.CenterCrop(crop_size),
+ MaybeToTensor(),
+ make_normalize_transform(mean=mean, std=std),
+ ]
+ return transforms.Compose(transforms_list)
diff --git a/mapper/models/dinov2/distributed/__init__.py b/mapper/models/dinov2/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23226f4536bf5acf4ffac242e9903d92863b246d
--- /dev/null
+++ b/mapper/models/dinov2/distributed/__init__.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+import random
+import re
+import socket
+from typing import Dict, List
+
+import torch
+import torch.distributed as dist
+
+_LOCAL_RANK = -1
+_LOCAL_WORLD_SIZE = -1
+
+
+def is_enabled() -> bool:
+ """
+ Returns:
+ True if distributed training is enabled
+ """
+ return dist.is_available() and dist.is_initialized()
+
+
+def get_global_size() -> int:
+ """
+ Returns:
+ The number of processes in the process group
+ """
+ return dist.get_world_size() if is_enabled() else 1
+
+
+def get_global_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the global process group.
+ """
+ return dist.get_rank() if is_enabled() else 0
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not is_enabled():
+ return 0
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
+ return _LOCAL_RANK
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not is_enabled():
+ return 1
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
+ return _LOCAL_WORLD_SIZE
+
+
+def is_main_process() -> bool:
+ """
+ Returns:
+ True if the current process is the main one.
+ """
+ return get_global_rank() == 0
+
+
+def _restrict_print_to_main_process() -> None:
+ """
+ This function disables printing when not in the main process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_main_process() or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def _get_master_port(seed: int = 0) -> int:
+ MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
+
+ master_port_str = os.environ.get("MASTER_PORT")
+ if master_port_str is None:
+ rng = random.Random(seed)
+ return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
+
+ return int(master_port_str)
+
+
+def _get_available_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ # A "" host address means INADDR_ANY i.e. binding to all interfaces.
+ # Note this is not compatible with IPv6.
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ return port
+
+
+_TORCH_DISTRIBUTED_ENV_VARS = (
+ "MASTER_ADDR",
+ "MASTER_PORT",
+ "RANK",
+ "WORLD_SIZE",
+ "LOCAL_RANK",
+ "LOCAL_WORLD_SIZE",
+)
+
+
+def _collect_env_vars() -> Dict[str, str]:
+ return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
+
+
+def _is_slurm_job_process() -> bool:
+ return "SLURM_JOB_ID" in os.environ
+
+
+def _parse_slurm_node_list(s: str) -> List[str]:
+ nodes = []
+ # Extract "hostname", "hostname[1-2,3,4-5]," substrings
+ p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
+ for m in p.finditer(s):
+ prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
+ for suffix in suffixes.split(","):
+ span = suffix.split("-")
+ if len(span) == 1:
+ nodes.append(prefix + suffix)
+ else:
+ width = len(span[0])
+ start, end = int(span[0]), int(span[1]) + 1
+ nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
+ return nodes
+
+
+def _check_env_variable(key: str, new_value: str):
+ # Only check for difference with preset environment variables
+ if key in os.environ and os.environ[key] != new_value:
+ raise RuntimeError(f"Cannot export environment variables as {key} is already set")
+
+
+class _TorchDistributedEnvironment:
+ def __init__(self):
+ self.master_addr = "127.0.0.1"
+ self.master_port = 0
+ self.rank = -1
+ self.world_size = -1
+ self.local_rank = -1
+ self.local_world_size = -1
+
+ if _is_slurm_job_process():
+ return self._set_from_slurm_env()
+
+ env_vars = _collect_env_vars()
+ if not env_vars:
+ # Environment is not set
+ pass
+ elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
+ # Environment is fully set
+ return self._set_from_preset_env()
+ else:
+ # Environment is partially set
+ collected_env_vars = ", ".join(env_vars.keys())
+ raise RuntimeError(f"Partially set environment: {collected_env_vars}")
+
+ if torch.cuda.device_count() > 0:
+ return self._set_from_local()
+
+ raise RuntimeError("Can't initialize PyTorch distributed environment")
+
+ # Slurm job created with sbatch, submitit, etc...
+ def _set_from_slurm_env(self):
+ # logger.info("Initialization from Slurm environment")
+ job_id = int(os.environ["SLURM_JOB_ID"])
+ node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
+ nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
+ assert len(nodes) == node_count
+
+ self.master_addr = nodes[0]
+ self.master_port = _get_master_port(seed=job_id)
+ self.rank = int(os.environ["SLURM_PROCID"])
+ self.world_size = int(os.environ["SLURM_NTASKS"])
+ assert self.rank < self.world_size
+ self.local_rank = int(os.environ["SLURM_LOCALID"])
+ self.local_world_size = self.world_size // node_count
+ assert self.local_rank < self.local_world_size
+
+ # Single node job with preset environment (i.e. torchrun)
+ def _set_from_preset_env(self):
+ # logger.info("Initialization from preset environment")
+ self.master_addr = os.environ["MASTER_ADDR"]
+ self.master_port = os.environ["MASTER_PORT"]
+ self.rank = int(os.environ["RANK"])
+ self.world_size = int(os.environ["WORLD_SIZE"])
+ assert self.rank < self.world_size
+ self.local_rank = int(os.environ["LOCAL_RANK"])
+ self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+ assert self.local_rank < self.local_world_size
+
+ # Single node and GPU job (i.e. local script run)
+ def _set_from_local(self):
+ # logger.info("Initialization from local")
+ self.master_addr = "127.0.0.1"
+ self.master_port = _get_available_port()
+ self.rank = 0
+ self.world_size = 1
+ self.local_rank = 0
+ self.local_world_size = 1
+
+ def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
+ # See the "Environment variable initialization" section from
+ # https://pytorch.org/docs/stable/distributed.html for the complete list of
+ # environment variables required for the env:// initialization method.
+ env_vars = {
+ "MASTER_ADDR": self.master_addr,
+ "MASTER_PORT": str(self.master_port),
+ "RANK": str(self.rank),
+ "WORLD_SIZE": str(self.world_size),
+ "LOCAL_RANK": str(self.local_rank),
+ "LOCAL_WORLD_SIZE": str(self.local_world_size),
+ }
+ if not overwrite:
+ for k, v in env_vars.items():
+ _check_env_variable(k, v)
+
+ os.environ.update(env_vars)
+ return self
+
+
+def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
+ """Enable distributed mode
+
+ Args:
+ set_cuda_current_device: If True, call torch.cuda.set_device() to set the
+ current PyTorch CUDA device to the one matching the local rank.
+ overwrite: If True, overwrites already set variables. Else fails.
+ """
+
+ global _LOCAL_RANK, _LOCAL_WORLD_SIZE
+ if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
+ raise RuntimeError("Distributed mode has already been enabled")
+ torch_env = _TorchDistributedEnvironment()
+ torch_env.export(overwrite=overwrite)
+
+ if set_cuda_current_device:
+ torch.cuda.set_device(torch_env.local_rank)
+
+ if allow_nccl_timeout:
+ # This allows to use torch distributed timeout in a NCCL backend
+ key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
+ if not overwrite:
+ _check_env_variable(key, value)
+ os.environ[key] = value
+
+ dist.init_process_group(backend="nccl")
+ dist.barrier()
+
+ # Finalize setup
+ _LOCAL_RANK = torch_env.local_rank
+ _LOCAL_WORLD_SIZE = torch_env.local_world_size
+ _restrict_print_to_main_process()
diff --git a/mapper/models/dinov2/eval/__init__.py b/mapper/models/dinov2/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/eval/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/eval/depth/__init__.py b/mapper/models/dinov2/eval/depth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/eval/depth/models/__init__.py b/mapper/models/dinov2/eval/depth/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5825181dc2189424b5c58d245b36919cbc5b2e
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .backbones import * # noqa: F403
+from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
+from .decode_heads import * # noqa: F403
+from .depther import * # noqa: F403
+from .losses import * # noqa: F403
diff --git a/mapper/models/dinov2/eval/depth/models/backbones/__init__.py b/mapper/models/dinov2/eval/depth/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/backbones/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .vision_transformer import DinoVisionTransformer
diff --git a/mapper/models/dinov2/eval/depth/models/backbones/vision_transformer.py b/mapper/models/dinov2/eval/depth/models/backbones/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bda46fd69eb7dabb8f5b60e6fa459fdc21aeab
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/backbones/vision_transformer.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mmcv.runner import BaseModule
+
+from ..builder import BACKBONES
+
+
+@BACKBONES.register_module()
+class DinoVisionTransformer(BaseModule):
+ """Vision Transformer."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
diff --git a/mapper/models/dinov2/eval/depth/models/builder.py b/mapper/models/dinov2/eval/depth/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c152643435308afcff60b07cd68ea979fe1d90cb
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/builder.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+from mmcv.cnn import MODELS as MMCV_MODELS
+from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
+from mmcv.utils import Registry
+
+MODELS = Registry("models", parent=MMCV_MODELS)
+ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
+
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+DEPTHER = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_depther(cfg, train_cfg=None, test_cfg=None):
+ """Build depther."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
+ assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
+ assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
+ return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/mapper/models/dinov2/eval/depth/models/decode_heads/__init__.py b/mapper/models/dinov2/eval/depth/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0f0754a5b01d7622c1f26bf3f60daea19da4e8
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/decode_heads/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dpt_head import DPTHead
+from .linear_head import BNHead
diff --git a/mapper/models/dinov2/eval/depth/models/decode_heads/decode_head.py b/mapper/models/dinov2/eval/depth/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d1237c6c6ea6463141084cc6f3d2f2c156c364f
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/decode_heads/decode_head.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import copy
+from abc import ABCMeta, abstractmethod
+
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
+
+from ...ops import resize
+from ..builder import build_loss
+
+
+class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (List): Input channels.
+ channels (int): Channels after modules, before conv_depth.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='SigLoss').
+ sampler (dict|None): The config of depth map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ min_depth (int): Min depth in dataset setting.
+ Default: 1e-3.
+ max_depth (int): Max depth in dataset setting.
+ Default: None.
+ norm_cfg (dict|None): Config of norm layers.
+ Default: None.
+ classify (bool): Whether predict depth in a cls.-reg. manner.
+ Default: False.
+ n_bins (int): The number of bins used in cls. step.
+ Default: 256.
+ bins_strategy (str): The discrete strategy used in cls. step.
+ Default: 'UD'.
+ norm_strategy (str): The norm strategy on cls. probability
+ distribution. Default: 'linear'
+ scale_up (str): Whether predict depth in a scale-up manner.
+ Default: False.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ channels=96,
+ conv_cfg=None,
+ act_cfg=dict(type="ReLU"),
+ loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
+ sampler=None,
+ align_corners=False,
+ min_depth=1e-3,
+ max_depth=None,
+ norm_cfg=None,
+ classify=False,
+ n_bins=256,
+ bins_strategy="UD",
+ norm_strategy="linear",
+ scale_up=False,
+ ):
+ super(DepthBaseDecodeHead, self).__init__()
+
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.act_cfg = act_cfg
+ if isinstance(loss_decode, dict):
+ self.loss_decode = build_loss(loss_decode)
+ elif isinstance(loss_decode, (list, tuple)):
+ self.loss_decode = nn.ModuleList()
+ for loss in loss_decode:
+ self.loss_decode.append(build_loss(loss))
+ self.align_corners = align_corners
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.norm_cfg = norm_cfg
+ # self.classify = classify
+ self.classify = True
+ self.n_bins = n_bins = 256
+ self.scale_up = scale_up
+
+ if self.classify:
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
+
+ self.bins_strategy = bins_strategy
+ self.norm_strategy = norm_strategy
+ self.softmax = nn.Softmax(dim=1)
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
+
+ self.fp16_enabled = False
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f"align_corners={self.align_corners}"
+ return s
+
+ @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs, img_metas):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): GT depth
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ depth_pred = self.forward(inputs, img_metas)
+ losses = self.losses(depth_pred, depth_gt)
+
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
+ losses.update(**log_imgs)
+
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output depth map.
+ """
+ return self.forward(inputs, img_metas)
+
+
+ def depth_pred(self, feat):
+ """Prediction each pixel."""
+ if self.classify:
+ # print("Here1\n\n\n\n")
+ logit = self.conv_depth(feat)
+ # if self.bins_strategy == "UD":
+ # bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+ # elif self.bins_strategy == "SID":
+ # bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+
+ # # following Adabins, default linear
+ # if self.norm_strategy == "linear":
+ # logit = torch.relu(logit)
+ # eps = 0.1
+ # logit = logit + eps
+ # logit = logit / logit.sum(dim=1, keepdim=True)
+ # elif self.norm_strategy == "softmax":
+ # logit = torch.softmax(logit, dim=1)
+ # elif self.norm_strategy == "sigmoid":
+ # logit = torch.sigmoid(logit)
+ # logit = logit / logit.sum(dim=1, keepdim=True)
+
+ # output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
+ # output = torch.einsum("ikmn,k->ikmn", [logit, bins])
+ output = logit
+
+ else:
+ if self.scale_up:
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
+ else:
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
+
+ return output
+
+ @force_fp32(apply_to=("depth_pred",))
+ def losses(self, depth_pred, depth_gt):
+ """Compute depth loss."""
+ loss = dict()
+ depth_pred = resize(
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
+ )
+ if not isinstance(self.loss_decode, nn.ModuleList):
+ losses_decode = [self.loss_decode]
+ else:
+ losses_decode = self.loss_decode
+ for loss_decode in losses_decode:
+ if loss_decode.loss_name not in loss:
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
+ else:
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
+ return loss
+
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
+ show_img = show_img.numpy().astype(np.float32)
+ show_img = mmcv.imdenormalize(
+ show_img,
+ img_meta["img_norm_cfg"]["mean"],
+ img_meta["img_norm_cfg"]["std"],
+ img_meta["img_norm_cfg"]["to_rgb"],
+ )
+ show_img = np.clip(show_img, 0, 255)
+ show_img = show_img.astype(np.uint8)
+ show_img = show_img[:, :, ::-1]
+ show_img = show_img.transpose(0, 2, 1)
+ show_img = show_img.transpose(1, 0, 2)
+
+ depth_pred = depth_pred / torch.max(depth_pred)
+ depth_gt = depth_gt / torch.max(depth_gt)
+
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
+
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
diff --git a/mapper/models/dinov2/eval/depth/models/decode_heads/dpt_head.py b/mapper/models/dinov2/eval/depth/models/decode_heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..846682d8cf29543c1103b19380e5f6b11305e1f9
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/decode_heads/dpt_head.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Linear, build_activation_layer
+from mmcv.runner import BaseModule
+
+from ...ops import resize
+from ..builder import HEADS
+from .decode_head import DepthBaseDecodeHead
+
+
+class Interpolate(nn.Module):
+ def __init__(self, scale_factor, mode, align_corners=False):
+ super(Interpolate, self).__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ x = self.interp(x.contiguous(), scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+ return x
+
+
+class HeadDepth(nn.Module):
+ def __init__(self, features):
+ super(HeadDepth, self).__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ # nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ # nn.ReLU(),
+ # nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(self, x):
+ x = self.head(x)
+ return x
+
+
+class ReassembleBlocks(BaseModule):
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
+ rearrange the feature vector to feature map.
+ Args:
+ in_channels (int): ViT feature channels. Default: 768.
+ out_channels (List): output channels of each stage.
+ Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
+ init_cfg (dict, optional): Initialization config dict. Default: None.
+ """
+
+ def __init__(
+ self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
+ ):
+ super(ReassembleBlocks, self).__init__(init_cfg)
+
+ assert readout_type in ["ignore", "add", "project"]
+ self.readout_type = readout_type
+ self.patch_size = patch_size
+
+ self.projects = nn.ModuleList(
+ [
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ act_cfg=None,
+ )
+ for out_channel in out_channels
+ ]
+ )
+
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+ if self.readout_type == "project":
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
+ )
+
+ def forward(self, inputs):
+ assert isinstance(inputs, list)
+ out = []
+ for i, x in enumerate(inputs):
+ assert len(x) == 2
+ x, cls_token = x[0], x[1]
+ feature_shape = x.shape
+ if self.readout_type == "project":
+ x = x.flatten(2).permute((0, 2, 1))
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ x = x.permute(0, 2, 1).reshape(feature_shape)
+ elif self.readout_type == "add":
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
+ x = x.reshape(feature_shape)
+ else:
+ pass
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+ out.append(x)
+ return out
+
+
+class PreActResidualConvUnit(BaseModule):
+ """ResidualConvUnit, pre-activate residual unit.
+ Args:
+ in_channels (int): number of channels in the input feature map.
+ act_cfg (dict): dictionary to construct and config activation layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ stride (int): stride of the first block. Default: 1
+ dilation (int): dilation rate for convs layers. Default: 1.
+ init_cfg (dict, optional): Initialization config dict. Default: None.
+ """
+
+ def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
+ super(PreActResidualConvUnit, self).__init__(init_cfg)
+
+ self.conv1 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ self.conv2 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ def forward(self, inputs):
+ inputs_ = inputs.clone()
+ x = self.conv1(inputs)
+ x = self.conv2(x)
+ return x + inputs_
+
+
+class FeatureFusionBlock(BaseModule):
+ """FeatureFusionBlock, merge feature map from different stages.
+ Args:
+ in_channels (int): Input channels.
+ act_cfg (dict): The activation config for ResidualConvUnit.
+ norm_cfg (dict): Config dict for normalization layer.
+ expand (bool): Whether expand the channels in post process block.
+ Default: False.
+ align_corners (bool): align_corner setting for bilinear upsample.
+ Default: True.
+ init_cfg (dict, optional): Initialization config dict. Default: None.
+ """
+
+ def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
+ super(FeatureFusionBlock, self).__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.expand = expand
+ self.align_corners = align_corners
+
+ self.out_channels = in_channels
+ if self.expand:
+ self.out_channels = in_channels // 2
+
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
+
+ self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
+ self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
+
+ def forward(self, *inputs):
+ x = inputs[0]
+ if len(inputs) == 2:
+ if x.shape != inputs[1].shape:
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
+ else:
+ res = inputs[1]
+ x = x + self.res_conv_unit1(res)
+ x = self.res_conv_unit2(x)
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
+ x = self.project(x)
+ return x
+
+
+@HEADS.register_module()
+class DPTHead(DepthBaseDecodeHead):
+ """Vision Transformers for Dense Prediction.
+ This head is implemented of `DPT `_.
+ Args:
+ embed_dims (int): The embed dimension of the ViT backbone.
+ Default: 768.
+ post_process_channels (List): Out channels of post process conv
+ layers. Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
+ expand_channels (bool): Whether expand the channels in post process
+ block. Default: False.
+ """
+
+ def __init__(
+ self,
+ embed_dims=768,
+ post_process_channels=[96, 192, 384, 768],
+ readout_type="ignore",
+ patch_size=16,
+ expand_channels=False,
+ **kwargs
+ ):
+ super(DPTHead, self).__init__(**kwargs)
+
+ self.in_channels = self.in_channels
+ self.expand_channels = expand_channels
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
+
+ self.post_process_channels = [
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
+ ]
+ self.convs = nn.ModuleList()
+ for channel in self.post_process_channels:
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
+ self.fusion_blocks = nn.ModuleList()
+ for _ in range(len(self.convs)):
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
+ self.fusion_blocks[0].res_conv_unit1 = None
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
+ self.num_fusion_blocks = len(self.fusion_blocks)
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
+ self.num_post_process_channels = len(self.post_process_channels)
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
+ assert self.num_reassemble_blocks == self.num_post_process_channels
+ self.conv_depth = HeadDepth(self.channels)
+
+ def forward(self, inputs, img_metas):
+ assert len(inputs) == self.num_reassemble_blocks
+ x = [inp for inp in inputs]
+ x = self.reassemble_blocks(x)
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
+ out = self.fusion_blocks[0](x[-1])
+ for i in range(1, len(self.fusion_blocks)):
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
+ projection = self.project(out)
+ out = self.depth_pred(projection)
+ return projection, out
diff --git a/mapper/models/dinov2/eval/depth/models/decode_heads/linear_head.py b/mapper/models/dinov2/eval/depth/models/decode_heads/linear_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1853516c138307c1e220e5f9eb7b9f4e89ccd4
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/decode_heads/linear_head.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from ...ops import resize
+from ..builder import HEADS
+from .decode_head import DepthBaseDecodeHead
+
+
+@HEADS.register_module()
+class BNHead(DepthBaseDecodeHead):
+ """Just a batchnorm."""
+
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
+ super().__init__(**kwargs)
+ self.input_transform = input_transform
+ self.in_index = in_index
+ self.upsample = upsample
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
+ if self.classify:
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if "concat" in self.input_transform:
+ inputs = [inputs[i] for i in self.in_index]
+ if "resize" in self.input_transform:
+ inputs = [
+ resize(
+ input=x,
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
+ mode="bilinear",
+ align_corners=self.align_corners,
+ )
+ for x in inputs
+ ]
+ inputs = torch.cat(inputs, dim=1)
+ elif self.input_transform == "multiple_select":
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
+ """Forward function for feature maps before classifying each pixel with
+ ``self.cls_seg`` fc.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
+ H, W) which is feature map for last layer of decoder head.
+ """
+ # accept lists (for cls token)
+ inputs = list(inputs)
+ for i, x in enumerate(inputs):
+ if len(x) == 2:
+ x, cls_token = x[0], x[1]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ cls_token = cls_token[:, :, None, None].expand_as(x)
+ inputs[i] = torch.cat((x, cls_token), 1)
+ else:
+ x = x[0]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ inputs[i] = x
+ x = self._transform_inputs(inputs)
+ # feats = self.bn(x)
+ return x
+
+ def forward(self, inputs, img_metas=None, **kwargs):
+ """Forward function."""
+ feature = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
+ output = self.depth_pred(feature)
+
+ return feature, output
diff --git a/mapper/models/dinov2/eval/depth/models/depther/__init__.py b/mapper/models/dinov2/eval/depth/models/depther/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be99743bf6c773d05f2b74524116e368c0cfcba0
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/depther/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .base import BaseDepther
+from .encoder_decoder import DepthEncoderDecoder
diff --git a/mapper/models/dinov2/eval/depth/models/depther/base.py b/mapper/models/dinov2/eval/depth/models/depther/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e133a825a888167f90d95d67803609d6cac7ff55
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/depther/base.py
@@ -0,0 +1,194 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from mmcv.runner import BaseModule, auto_fp16
+
+
+class BaseDepther(BaseModule, metaclass=ABCMeta):
+ """Base class for depther."""
+
+ def __init__(self, init_cfg=None):
+ super(BaseDepther, self).__init__(init_cfg)
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the depther has neck"""
+ return hasattr(self, "neck") and self.neck is not None
+
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the depther has auxiliary head"""
+ return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
+
+ @property
+ def with_decode_head(self):
+ """bool: whether the depther has decode head"""
+ return hasattr(self, "decode_head") and self.decode_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic depth map of the same size as input."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
+ if not isinstance(var, list):
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
+ # all images in the same aug batch all of the same ori_shape and pad
+ # shape
+ for img_meta in img_metas:
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_["img_shape"] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=("img",))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+
+ # split losses and images
+ real_losses = {}
+ log_imgs = {}
+ for k, v in losses.items():
+ if "img" in k:
+ log_imgs[k] = v
+ else:
+ real_losses[k] = v
+
+ loss, log_vars = self._parse_losses(real_losses)
+
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
+
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
+
+ log_vars["loss"] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
diff --git a/mapper/models/dinov2/eval/depth/models/depther/encoder_decoder.py b/mapper/models/dinov2/eval/depth/models/depther/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0f394cfa48fa43602816278288440fdb8c8cae
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/depther/encoder_decoder.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from ...models import builder
+from ...models.builder import DEPTHER
+from ...ops import resize
+from .base import BaseDepther
+
+
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+ Returns:
+
+ dict: The dict with keys updated with ``prefix``.
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f"{prefix}.{name}"] = value
+
+ return outputs
+
+
+@DEPTHER.register_module()
+class DepthEncoderDecoder(BaseDepther):
+ """Encoder Decoder depther.
+
+ EncoderDecoder typically consists of backbone, (neck) and decode_head.
+ """
+
+ def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
+ super(DepthEncoderDecoder, self).__init__(init_cfg)
+ if pretrained is not None:
+ assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
+ backbone.pretrained = pretrained
+ self.backbone = builder.build_backbone(backbone)
+ self._init_decode_head(decode_head)
+
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ self.decode_head = builder.build_head(decode_head)
+ self.align_corners = self.decode_head.align_corners
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, img, img_metas, rescale=True, size=None, scale=2):
+ """Encode images with backbone and decode into a depth estimation
+ map of the same size as input."""
+ with torch.no_grad():
+ x = self.extract_feat(img)
+ projection, out = self._decode_head_forward_test(x, img_metas)
+ projection = resize(input=projection, size=(img.shape[2] // scale, img.shape[3] // scale),
+ mode="bilinear", align_corners=self.align_corners)
+
+ # crop the pred depth to the certain range.
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
+ if rescale:
+ if size is None:
+ if img_metas is not None:
+ size = img_metas[0]["ori_shape"][:2]
+ else:
+ size = (img.shape[2] // scale, img.shape[3] // scale)
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
+ # print(projection.shape, out.shape, "\n\n\n\n")
+ return projection, out
+
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
+ losses.update(add_prefix(loss_decode, "decode"))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ return depth_pred
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ feat, depth = self.encode_decode(img, None)
+
+ return depth
+
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): Depth gt
+ used if the architecture supports depth estimation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ # the last of x saves the info from neck
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
+
+ losses.update(loss_decode)
+
+ return losses
+
+ def whole_inference(self, img, img_meta, rescale, size=None):
+ """Inference with full image."""
+ features, depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
+ # print(features.shape, depth_pred.shape, "\n\n\n\n\n\n")
+
+ return features, depth_pred
+
+ def slide_inference(self, img, img_meta, rescale):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = img.size()
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ _, depth_pred = self.encode_decode(crop_img, img_meta, rescale)
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ return preds
+
+ def inference(self, img, img_meta, rescale, size=None):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output depth map.
+ """
+
+ assert self.test_cfg.mode in ["slide", "whole"]
+ ori_shape = img_meta[0]["ori_shape"]
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+ if self.test_cfg.mode == "slide":
+ depth_pred = self.slide_inference(img, img_meta, rescale)
+ else:
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
+ output = depth_pred
+ flip = img_meta[0]["flip"]
+ if flip:
+ flip_direction = img_meta[0]["flip_direction"]
+ assert flip_direction in ["horizontal", "vertical"]
+ if flip_direction == "horizontal":
+ output = output.flip(dims=(3,))
+ elif flip_direction == "vertical":
+ output = output.flip(dims=(2,))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ depth_pred = self.inference(img, img_meta, rescale)
+ if torch.onnx.is_in_onnx_export():
+ # our inference backend only support 4D output
+ depth_pred = depth_pred.unsqueeze(0)
+ return depth_pred
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented depth logit inplace
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
+ depth_pred += cur_depth_pred
+ depth_pred /= len(imgs)
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
diff --git a/mapper/models/dinov2/eval/depth/models/losses/__init__.py b/mapper/models/dinov2/eval/depth/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f86242e342776da2e0acc61150d15a8d58ff1e0
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/losses/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .gradientloss import GradientLoss
+from .sigloss import SigLoss
diff --git a/mapper/models/dinov2/eval/depth/models/losses/gradientloss.py b/mapper/models/dinov2/eval/depth/models/losses/gradientloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1599878a6b70cdff4f8467e1e875f0d13ea89eca
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/losses/gradientloss.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from ...models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class GradientLoss(nn.Module):
+ """GradientLoss.
+
+ Adapted from https://www.cs.cornell.edu/projects/megadepth/
+
+ Args:
+ valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
+ """
+
+ def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
+ super(GradientLoss, self).__init__()
+ self.valid_mask = valid_mask
+ self.loss_weight = loss_weight
+ self.max_depth = max_depth
+ self.loss_name = loss_name
+
+ self.eps = 0.001 # avoid grad explode
+
+ def gradientloss(self, input, target):
+ input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
+ target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]
+
+ gradient_loss = 0
+ for input, target in zip(input_downscaled, target_downscaled):
+ if self.valid_mask:
+ mask = target > 0
+ if self.max_depth is not None:
+ mask = torch.logical_and(target > 0, target <= self.max_depth)
+ N = torch.sum(mask)
+ else:
+ mask = torch.ones_like(target)
+ N = input.numel()
+ input_log = torch.log(input + self.eps)
+ target_log = torch.log(target + self.eps)
+ log_d_diff = input_log - target_log
+
+ log_d_diff = torch.mul(log_d_diff, mask)
+
+ v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
+ v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
+ v_gradient = torch.mul(v_gradient, v_mask)
+
+ h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
+ h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
+ h_gradient = torch.mul(h_gradient, h_mask)
+
+ gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
+
+ return gradient_loss
+
+ def forward(self, depth_pred, depth_gt):
+ """Forward function."""
+
+ gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
+ return gradient_loss
diff --git a/mapper/models/dinov2/eval/depth/models/losses/sigloss.py b/mapper/models/dinov2/eval/depth/models/losses/sigloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12fad3e6151e4b975dd055193fdaec0206d4a14
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/models/losses/sigloss.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from ...models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class SigLoss(nn.Module):
+ """SigLoss.
+
+ This follows `AdaBins `_.
+
+ Args:
+ valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
+ warm_up (bool): A simple warm up stage to help convergence. Default: False.
+ warm_iter (int): The number of warm up stage. Default: 100.
+ """
+
+ def __init__(
+ self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
+ ):
+ super(SigLoss, self).__init__()
+ self.valid_mask = valid_mask
+ self.loss_weight = loss_weight
+ self.max_depth = max_depth
+ self.loss_name = loss_name
+
+ self.eps = 0.001 # avoid grad explode
+
+ # HACK: a hack implementation for warmup sigloss
+ self.warm_up = warm_up
+ self.warm_iter = warm_iter
+ self.warm_up_counter = 0
+
+ def sigloss(self, input, target):
+ if self.valid_mask:
+ valid_mask = target > 0
+ if self.max_depth is not None:
+ valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
+ input = input[valid_mask]
+ target = target[valid_mask]
+
+ if self.warm_up:
+ if self.warm_up_counter < self.warm_iter:
+ g = torch.log(input + self.eps) - torch.log(target + self.eps)
+ g = 0.15 * torch.pow(torch.mean(g), 2)
+ self.warm_up_counter += 1
+ return torch.sqrt(g)
+
+ g = torch.log(input + self.eps) - torch.log(target + self.eps)
+ Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
+ return torch.sqrt(Dg)
+
+ def forward(self, depth_pred, depth_gt):
+ """Forward function."""
+
+ loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
+ return loss_depth
diff --git a/mapper/models/dinov2/eval/depth/ops/__init__.py b/mapper/models/dinov2/eval/depth/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78181c29581a281b5f42cf12078636aaeb43b5a5
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/ops/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .wrappers import resize
diff --git a/mapper/models/dinov2/eval/depth/ops/wrappers.py b/mapper/models/dinov2/eval/depth/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e
--- /dev/null
+++ b/mapper/models/dinov2/eval/depth/ops/wrappers.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch.nn.functional as F
+
+
+def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+ if output_h > input_h or output_w > output_h:
+ if (
+ (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
+ and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)
+ ):
+ warnings.warn(
+ f"When align_corners={align_corners}, "
+ "the output would more aligned if "
+ f"input size {(input_h, input_w)} is `x+1` and "
+ f"out size {(output_h, output_w)} is `nx+1`"
+ )
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/mapper/models/dinov2/eval/knn.py b/mapper/models/dinov2/eval/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3a4845da1313a6db6b8345bb9a98230fcd24acf
--- /dev/null
+++ b/mapper/models/dinov2/eval/knn.py
@@ -0,0 +1,404 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+from functools import partial
+import json
+import logging
+import os
+import sys
+from typing import List, Optional
+
+import torch
+from torch.nn.functional import one_hot, softmax
+
+import dinov2.distributed as distributed
+from dinov2.data import SamplerType, make_data_loader, make_dataset
+from dinov2.data.transforms import make_classification_eval_transform
+from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
+from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+from dinov2.eval.setup import setup_and_build_model
+from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_args_parser(
+ description: Optional[str] = None,
+ parents: Optional[List[argparse.ArgumentParser]] = None,
+ add_help: bool = True,
+):
+ parents = parents or []
+ setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
+ parents = [setup_args_parser]
+ parser = argparse.ArgumentParser(
+ description=description,
+ parents=parents,
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--train-dataset",
+ dest="train_dataset_str",
+ type=str,
+ help="Training dataset",
+ )
+ parser.add_argument(
+ "--val-dataset",
+ dest="val_dataset_str",
+ type=str,
+ help="Validation dataset",
+ )
+ parser.add_argument(
+ "--nb_knn",
+ nargs="+",
+ type=int,
+ help="Number of NN to use. 20 is usually working the best.",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ help="Temperature used in the voting coefficient",
+ )
+ parser.add_argument(
+ "--gather-on-cpu",
+ action="store_true",
+ help="Whether to gather the train features on cpu, slower"
+ "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ help="Batch size.",
+ )
+ parser.add_argument(
+ "--n-per-class-list",
+ nargs="+",
+ type=int,
+ help="Number to take per class",
+ )
+ parser.add_argument(
+ "--n-tries",
+ type=int,
+ help="Number of tries",
+ )
+ parser.set_defaults(
+ train_dataset_str="ImageNet:split=TRAIN",
+ val_dataset_str="ImageNet:split=VAL",
+ nb_knn=[10, 20, 100, 200],
+ temperature=0.07,
+ batch_size=256,
+ n_per_class_list=[-1],
+ n_tries=1,
+ )
+ return parser
+
+
+class KnnModule(torch.nn.Module):
+ """
+ Gets knn of test features from all processes on a chunk of the train features
+
+ Each rank gets a chunk of the train features as well as a chunk of the test features.
+ In `compute_neighbors`, for each rank one after the other, its chunk of test features
+ is sent to all devices, partial knns are computed with each chunk of train features
+ then collated back on the original device.
+ """
+
+ def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
+ super().__init__()
+
+ self.global_rank = distributed.get_global_rank()
+ self.global_size = distributed.get_global_size()
+
+ self.device = device
+ self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
+ self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)
+
+ self.nb_knn = nb_knn
+ self.max_k = max(self.nb_knn)
+ self.T = T
+ self.num_classes = num_classes
+
+ def _get_knn_sims_and_labels(self, similarity, train_labels):
+ topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
+ neighbors_labels = torch.gather(train_labels, 1, indices)
+ return topk_sims, neighbors_labels
+
+ def _similarity_for_rank(self, features_rank, source_rank):
+ # Send the features from `source_rank` to all ranks
+ broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
+ torch.distributed.broadcast(broadcast_shape, source_rank)
+
+ broadcasted = features_rank
+ if self.global_rank != source_rank:
+ broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
+ torch.distributed.broadcast(broadcasted, source_rank)
+
+ # Compute the neighbors for `source_rank` among `train_features_rank_T`
+ similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
+ candidate_labels = self.candidates.expand(len(similarity_rank), -1)
+ return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
+
+ def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
+ # Gather all neighbors for `target_rank`
+ topk_sims_rank = retrieved_rank = None
+ if self.global_rank == target_rank:
+ topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
+ retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
+
+ torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
+ torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
+
+ if self.global_rank == target_rank:
+ # Perform a second top-k on the k * global_size retrieved neighbors
+ topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
+ retrieved_rank = torch.cat(retrieved_rank, dim=1)
+ results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
+ return results
+ return None
+
+ def compute_neighbors(self, features_rank):
+ for rank in range(self.global_size):
+ topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
+ results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
+ if results is not None:
+ topk_sims_rank, neighbors_labels_rank = results
+ return topk_sims_rank, neighbors_labels_rank
+
+ def forward(self, features_rank):
+ """
+ Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
+ """
+ assert all(k <= self.max_k for k in self.nb_knn)
+
+ topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
+ batch_size = neighbors_labels.shape[0]
+ topk_sims_transform = softmax(topk_sims / self.T, 1)
+ matmul = torch.mul(
+ one_hot(neighbors_labels, num_classes=self.num_classes),
+ topk_sims_transform.view(batch_size, -1, 1),
+ )
+ probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
+ return probas_for_k
+
+
+class DictKeysModule(torch.nn.Module):
+ def __init__(self, keys):
+ super().__init__()
+ self.keys = keys
+
+ def forward(self, features_dict, targets):
+ for k in self.keys:
+ features_dict = features_dict[k]
+ return {"preds": features_dict, "target": targets}
+
+
+def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
+ modules = {}
+ mapping = create_class_indices_mapping(train_labels)
+ for npc in n_per_class_list:
+ if npc < 0: # Only one try needed when using the full data
+ full_module = module(
+ train_features=train_features,
+ train_labels=train_labels,
+ nb_knn=nb_knn,
+ )
+ modules["full"] = ModuleDictWithForward({"1": full_module})
+ continue
+ all_tries = {}
+ for t in range(n_tries):
+ final_indices = filter_train(mapping, npc, seed=t)
+ k_list = list(set(nb_knn + [npc]))
+ k_list = sorted([el for el in k_list if el <= npc])
+ all_tries[str(t)] = module(
+ train_features=train_features[final_indices],
+ train_labels=train_labels[final_indices],
+ nb_knn=k_list,
+ )
+ modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
+
+ return ModuleDictWithForward(modules)
+
+
+def filter_train(mapping, n_per_class, seed):
+ torch.manual_seed(seed)
+ final_indices = []
+ for k in mapping.keys():
+ index = torch.randperm(len(mapping[k]))[:n_per_class]
+ final_indices.append(mapping[k][index])
+ return torch.cat(final_indices).squeeze()
+
+
+def create_class_indices_mapping(labels):
+ unique_labels, inverse = torch.unique(labels, return_inverse=True)
+ mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
+ return mapping
+
+
+class ModuleDictWithForward(torch.nn.ModuleDict):
+ def forward(self, *args, **kwargs):
+ return {k: module(*args, **kwargs) for k, module in self._modules.items()}
+
+
+def eval_knn(
+ model,
+ train_dataset,
+ val_dataset,
+ accuracy_averaging,
+ nb_knn,
+ temperature,
+ batch_size,
+ num_workers,
+ gather_on_cpu,
+ n_per_class_list=[-1],
+ n_tries=1,
+):
+ model = ModelWithNormalize(model)
+
+ logger.info("Extracting features for train set...")
+ train_features, train_labels = extract_features(
+ model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
+ )
+ logger.info(f"Train features created, shape {train_features.shape}.")
+
+ val_dataloader = make_data_loader(
+ dataset=val_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler_type=SamplerType.DISTRIBUTED,
+ drop_last=False,
+ shuffle=False,
+ persistent_workers=True,
+ )
+ num_classes = train_labels.max() + 1
+ metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)
+
+ device = torch.cuda.current_device()
+ partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
+ knn_module_dict = create_module_dict(
+ module=partial_module,
+ n_per_class_list=n_per_class_list,
+ n_tries=n_tries,
+ nb_knn=nb_knn,
+ train_features=train_features,
+ train_labels=train_labels,
+ )
+ postprocessors, metrics = {}, {}
+ for n_per_class, knn_module in knn_module_dict.items():
+ for t, knn_try in knn_module.items():
+ postprocessors = {
+ **postprocessors,
+ **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
+ }
+ metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
+ model_with_knn = torch.nn.Sequential(model, knn_module_dict)
+
+ # ============ evaluation ... ============
+ logger.info("Start the k-NN classification.")
+ _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)
+
+ # Averaging the results over the n tries for each value of n_per_class
+ for n_per_class, knn_module in knn_module_dict.items():
+ first_try = list(knn_module.keys())[0]
+ k_list = knn_module[first_try].nb_knn
+ for k in k_list:
+ keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5`
+ results_dict[(n_per_class, k)] = {
+ key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
+ for key in keys
+ }
+ for t in knn_module.keys():
+ del results_dict[(n_per_class, t, k)]
+
+ return results_dict
+
+
+def eval_knn_with_model(
+ model,
+ output_dir,
+ train_dataset_str="ImageNet:split=TRAIN",
+ val_dataset_str="ImageNet:split=VAL",
+ nb_knn=(10, 20, 100, 200),
+ temperature=0.07,
+ autocast_dtype=torch.float,
+ accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
+ transform=None,
+ gather_on_cpu=False,
+ batch_size=256,
+ num_workers=5,
+ n_per_class_list=[-1],
+ n_tries=1,
+):
+ transform = transform or make_classification_eval_transform()
+
+ train_dataset = make_dataset(
+ dataset_str=train_dataset_str,
+ transform=transform,
+ )
+ val_dataset = make_dataset(
+ dataset_str=val_dataset_str,
+ transform=transform,
+ )
+
+ with torch.cuda.amp.autocast(dtype=autocast_dtype):
+ results_dict_knn = eval_knn(
+ model=model,
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ accuracy_averaging=accuracy_averaging,
+ nb_knn=nb_knn,
+ temperature=temperature,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ gather_on_cpu=gather_on_cpu,
+ n_per_class_list=n_per_class_list,
+ n_tries=n_tries,
+ )
+
+ results_dict = {}
+ if distributed.is_main_process():
+ for knn_ in results_dict_knn.keys():
+ top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
+ top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
+ results_dict[f"{knn_} Top 1"] = top1
+ results_dict[f"{knn_} Top 5"] = top5
+ logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")
+
+ metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
+ with open(metrics_file_path, "a") as f:
+ for k, v in results_dict.items():
+ f.write(json.dumps({k: v}) + "\n")
+
+ if distributed.is_enabled():
+ torch.distributed.barrier()
+ return results_dict
+
+
+def main(args):
+ model, autocast_dtype = setup_and_build_model(args)
+ eval_knn_with_model(
+ model=model,
+ output_dir=args.output_dir,
+ train_dataset_str=args.train_dataset_str,
+ val_dataset_str=args.val_dataset_str,
+ nb_knn=args.nb_knn,
+ temperature=args.temperature,
+ autocast_dtype=autocast_dtype,
+ accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
+ transform=None,
+ gather_on_cpu=args.gather_on_cpu,
+ batch_size=args.batch_size,
+ num_workers=5,
+ n_per_class_list=args.n_per_class_list,
+ n_tries=args.n_tries,
+ )
+ return 0
+
+
+if __name__ == "__main__":
+ description = "DINOv2 k-NN evaluation"
+ args_parser = get_args_parser(description=description)
+ args = args_parser.parse_args()
+ sys.exit(main(args))
diff --git a/mapper/models/dinov2/eval/linear.py b/mapper/models/dinov2/eval/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd4c5de5a041be8a188f007257d1e91b6d6921e
--- /dev/null
+++ b/mapper/models/dinov2/eval/linear.py
@@ -0,0 +1,625 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+from functools import partial
+import json
+import logging
+import os
+import sys
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+
+from dinov2.data import SamplerType, make_data_loader, make_dataset
+from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform
+import dinov2.distributed as distributed
+from dinov2.eval.metrics import MetricType, build_metric
+from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+from dinov2.eval.setup import setup_and_build_model
+from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate
+from dinov2.logging import MetricLogger
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_args_parser(
+ description: Optional[str] = None,
+ parents: Optional[List[argparse.ArgumentParser]] = None,
+ add_help: bool = True,
+):
+ parents = parents or []
+ setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
+ parents = [setup_args_parser]
+ parser = argparse.ArgumentParser(
+ description=description,
+ parents=parents,
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--train-dataset",
+ dest="train_dataset_str",
+ type=str,
+ help="Training dataset",
+ )
+ parser.add_argument(
+ "--val-dataset",
+ dest="val_dataset_str",
+ type=str,
+ help="Validation dataset",
+ )
+ parser.add_argument(
+ "--test-datasets",
+ dest="test_dataset_strs",
+ type=str,
+ nargs="+",
+ help="Test datasets, none to reuse the validation dataset",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ help="Number of training epochs",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ help="Batch Size (per GPU)",
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ help="Number de Workers",
+ )
+ parser.add_argument(
+ "--epoch-length",
+ type=int,
+ help="Length of an epoch in number of iterations",
+ )
+ parser.add_argument(
+ "--save-checkpoint-frequency",
+ type=int,
+ help="Number of epochs between two named checkpoint saves.",
+ )
+ parser.add_argument(
+ "--eval-period-iterations",
+ type=int,
+ help="Number of iterations between two evaluations.",
+ )
+ parser.add_argument(
+ "--learning-rates",
+ nargs="+",
+ type=float,
+ help="Learning rates to grid search.",
+ )
+ parser.add_argument(
+ "--no-resume",
+ action="store_true",
+ help="Whether to not resume from existing checkpoints",
+ )
+ parser.add_argument(
+ "--val-metric-type",
+ type=MetricType,
+ choices=list(MetricType),
+ help="Validation metric",
+ )
+ parser.add_argument(
+ "--test-metric-types",
+ type=MetricType,
+ choices=list(MetricType),
+ nargs="+",
+ help="Evaluation metric",
+ )
+ parser.add_argument(
+ "--classifier-fpath",
+ type=str,
+ help="Path to a file containing pretrained linear classifiers",
+ )
+ parser.add_argument(
+ "--val-class-mapping-fpath",
+ type=str,
+ help="Path to a file containing a mapping to adjust classifier outputs",
+ )
+ parser.add_argument(
+ "--test-class-mapping-fpaths",
+ nargs="+",
+ type=str,
+ help="Path to a file containing a mapping to adjust classifier outputs",
+ )
+ parser.set_defaults(
+ train_dataset_str="ImageNet:split=TRAIN",
+ val_dataset_str="ImageNet:split=VAL",
+ test_dataset_strs=None,
+ epochs=10,
+ batch_size=128,
+ num_workers=8,
+ epoch_length=1250,
+ save_checkpoint_frequency=20,
+ eval_period_iterations=1250,
+ learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1],
+ val_metric_type=MetricType.MEAN_ACCURACY,
+ test_metric_types=None,
+ classifier_fpath=None,
+ val_class_mapping_fpath=None,
+ test_class_mapping_fpaths=[None],
+ )
+ return parser
+
+
+def has_ddp_wrapper(m: nn.Module) -> bool:
+ return isinstance(m, DistributedDataParallel)
+
+
+def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
+ return m.module if has_ddp_wrapper(m) else m
+
+
+def _pad_and_collate(batch):
+ maxlen = max(len(targets) for image, targets in batch)
+ padded_batch = [
+ (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch
+ ]
+ return torch.utils.data.default_collate(padded_batch)
+
+
+def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool):
+ intermediate_output = x_tokens_list[-use_n_blocks:]
+ output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1)
+ if use_avgpool:
+ output = torch.cat(
+ (
+ output,
+ torch.mean(intermediate_output[-1][0], dim=1), # patch tokens
+ ),
+ dim=-1,
+ )
+ output = output.reshape(output.shape[0], -1)
+ return output.float()
+
+
+class LinearClassifier(nn.Module):
+ """Linear layer to train on top of frozen features"""
+
+ def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000):
+ super().__init__()
+ self.out_dim = out_dim
+ self.use_n_blocks = use_n_blocks
+ self.use_avgpool = use_avgpool
+ self.num_classes = num_classes
+ self.linear = nn.Linear(out_dim, num_classes)
+ self.linear.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear.bias.data.zero_()
+
+ def forward(self, x_tokens_list):
+ output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool)
+ return self.linear(output)
+
+
+class AllClassifiers(nn.Module):
+ def __init__(self, classifiers_dict):
+ super().__init__()
+ self.classifiers_dict = nn.ModuleDict()
+ self.classifiers_dict.update(classifiers_dict)
+
+ def forward(self, inputs):
+ return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()}
+
+ def __len__(self):
+ return len(self.classifiers_dict)
+
+
+class LinearPostprocessor(nn.Module):
+ def __init__(self, linear_classifier, class_mapping=None):
+ super().__init__()
+ self.linear_classifier = linear_classifier
+ self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping))
+
+ def forward(self, samples, targets):
+ preds = self.linear_classifier(samples)
+ return {
+ "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds,
+ "target": targets,
+ }
+
+
+def scale_lr(learning_rates, batch_size):
+ return learning_rates * (batch_size * distributed.get_global_size()) / 256.0
+
+
+def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000):
+ linear_classifiers_dict = nn.ModuleDict()
+ optim_param_groups = []
+ for n in n_last_blocks_list:
+ for avgpool in [False, True]:
+ for _lr in learning_rates:
+ lr = scale_lr(_lr, batch_size)
+ out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1]
+ linear_classifier = LinearClassifier(
+ out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes
+ )
+ linear_classifier = linear_classifier.cuda()
+ linear_classifiers_dict[
+ f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
+ ] = linear_classifier
+ optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})
+
+ linear_classifiers = AllClassifiers(linear_classifiers_dict)
+ if distributed.is_enabled():
+ linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)
+
+ return linear_classifiers, optim_param_groups
+
+
+@torch.no_grad()
+def evaluate_linear_classifiers(
+ feature_model,
+ linear_classifiers,
+ data_loader,
+ metric_type,
+ metrics_file_path,
+ training_num_classes,
+ iteration,
+ prefixstring="",
+ class_mapping=None,
+ best_classifier_on_val=None,
+):
+ logger.info("running validation !")
+
+ num_classes = len(class_mapping) if class_mapping is not None else training_num_classes
+ metric = build_metric(metric_type, num_classes=num_classes)
+ postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()}
+ metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}
+
+ _, results_dict_temp = evaluate(
+ feature_model,
+ data_loader,
+ postprocessors,
+ metrics,
+ torch.cuda.current_device(),
+ )
+
+ logger.info("")
+ results_dict = {}
+ max_accuracy = 0
+ best_classifier = ""
+ for i, (classifier_string, metric) in enumerate(results_dict_temp.items()):
+ logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
+ if (
+ best_classifier_on_val is None and metric["top-1"].item() > max_accuracy
+ ) or classifier_string == best_classifier_on_val:
+ max_accuracy = metric["top-1"].item()
+ best_classifier = classifier_string
+
+ results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}
+
+ logger.info(f"best classifier: {results_dict['best_classifier']}")
+
+ if distributed.is_main_process():
+ with open(metrics_file_path, "a") as f:
+ f.write(f"iter: {iteration}\n")
+ for k, v in results_dict.items():
+ f.write(json.dumps({k: v}) + "\n")
+ f.write("\n")
+
+ return results_dict
+
+
+def eval_linear(
+ *,
+ feature_model,
+ linear_classifiers,
+ train_data_loader,
+ val_data_loader,
+ metrics_file_path,
+ optimizer,
+ scheduler,
+ output_dir,
+ max_iter,
+ checkpoint_period, # In number of iter, creates a new file every period
+ running_checkpoint_period, # Period to update main checkpoint file
+ eval_period,
+ metric_type,
+ training_num_classes,
+ resume=True,
+ classifier_fpath=None,
+ val_class_mapping=None,
+):
+ checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
+ start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
+
+ periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter)
+ iteration = start_iter
+ logger.info("Starting training from iteration {}".format(start_iter))
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Training"
+
+ for data, labels in metric_logger.log_every(
+ train_data_loader,
+ 10,
+ header,
+ max_iter,
+ start_iter,
+ ):
+ data = data.cuda(non_blocking=True)
+ labels = labels.cuda(non_blocking=True)
+
+ features = feature_model(data)
+ outputs = linear_classifiers(features)
+
+ losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
+ loss = sum(losses.values())
+
+ # compute the gradients
+ optimizer.zero_grad()
+ loss.backward()
+
+ # step
+ optimizer.step()
+ scheduler.step()
+
+ # log
+ if iteration % 10 == 0:
+ torch.cuda.synchronize()
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ print("lr", optimizer.param_groups[0]["lr"])
+
+ if iteration - start_iter > 5:
+ if iteration % running_checkpoint_period == 0:
+ torch.cuda.synchronize()
+ if distributed.is_main_process():
+ logger.info("Checkpointing running_checkpoint")
+ periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration)
+ torch.cuda.synchronize()
+ periodic_checkpointer.step(iteration)
+
+ if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
+ _ = evaluate_linear_classifiers(
+ feature_model=feature_model,
+ linear_classifiers=remove_ddp_wrapper(linear_classifiers),
+ data_loader=val_data_loader,
+ metrics_file_path=metrics_file_path,
+ prefixstring=f"ITER: {iteration}",
+ metric_type=metric_type,
+ training_num_classes=training_num_classes,
+ iteration=iteration,
+ class_mapping=val_class_mapping,
+ )
+ torch.cuda.synchronize()
+
+ iteration = iteration + 1
+
+ val_results_dict = evaluate_linear_classifiers(
+ feature_model=feature_model,
+ linear_classifiers=remove_ddp_wrapper(linear_classifiers),
+ data_loader=val_data_loader,
+ metrics_file_path=metrics_file_path,
+ metric_type=metric_type,
+ training_num_classes=training_num_classes,
+ iteration=iteration,
+ class_mapping=val_class_mapping,
+ )
+ return val_results_dict, feature_model, linear_classifiers, iteration
+
+
+def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type):
+ test_dataset = make_dataset(
+ dataset_str=test_dataset_str,
+ transform=make_classification_eval_transform(),
+ )
+ test_data_loader = make_data_loader(
+ dataset=test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler_type=SamplerType.DISTRIBUTED,
+ drop_last=False,
+ shuffle=False,
+ persistent_workers=False,
+ collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None,
+ )
+ return test_data_loader
+
+
+def test_on_datasets(
+ feature_model,
+ linear_classifiers,
+ test_dataset_strs,
+ batch_size,
+ num_workers,
+ test_metric_types,
+ metrics_file_path,
+ training_num_classes,
+ iteration,
+ best_classifier_on_val,
+ prefixstring="",
+ test_class_mappings=[None],
+):
+ results_dict = {}
+ for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types):
+ logger.info(f"Testing on {test_dataset_str}")
+ test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type)
+ dataset_results_dict = evaluate_linear_classifiers(
+ feature_model,
+ remove_ddp_wrapper(linear_classifiers),
+ test_data_loader,
+ metric_type,
+ metrics_file_path,
+ training_num_classes,
+ iteration,
+ prefixstring="",
+ class_mapping=class_mapping,
+ best_classifier_on_val=best_classifier_on_val,
+ )
+ results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"]
+ return results_dict
+
+
+def run_eval_linear(
+ model,
+ output_dir,
+ train_dataset_str,
+ val_dataset_str,
+ batch_size,
+ epochs,
+ epoch_length,
+ num_workers,
+ save_checkpoint_frequency,
+ eval_period_iterations,
+ learning_rates,
+ autocast_dtype,
+ test_dataset_strs=None,
+ resume=True,
+ classifier_fpath=None,
+ val_class_mapping_fpath=None,
+ test_class_mapping_fpaths=[None],
+ val_metric_type=MetricType.MEAN_ACCURACY,
+ test_metric_types=None,
+):
+ seed = 0
+
+ if test_dataset_strs is None:
+ test_dataset_strs = [val_dataset_str]
+ if test_metric_types is None:
+ test_metric_types = [val_metric_type] * len(test_dataset_strs)
+ else:
+ assert len(test_metric_types) == len(test_dataset_strs)
+ assert len(test_dataset_strs) == len(test_class_mapping_fpaths)
+
+ train_transform = make_classification_train_transform()
+ train_dataset = make_dataset(
+ dataset_str=train_dataset_str,
+ transform=train_transform,
+ )
+ training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
+ sampler_type = SamplerType.SHARDED_INFINITE
+ # sampler_type = SamplerType.INFINITE
+
+ n_last_blocks_list = [1, 4]
+ n_last_blocks = max(n_last_blocks_list)
+ autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
+ feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
+ sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())
+
+ linear_classifiers, optim_param_groups = setup_linear_classifiers(
+ sample_output,
+ n_last_blocks_list,
+ learning_rates,
+ batch_size,
+ training_num_classes,
+ )
+
+ optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
+ max_iter = epochs * epoch_length
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
+ checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
+ start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
+ train_data_loader = make_data_loader(
+ dataset=train_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ shuffle=True,
+ seed=seed,
+ sampler_type=sampler_type,
+ sampler_advance=start_iter,
+ drop_last=True,
+ persistent_workers=True,
+ )
+ val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
+
+ checkpoint_period = save_checkpoint_frequency * epoch_length
+
+ if val_class_mapping_fpath is not None:
+ logger.info(f"Using class mapping from {val_class_mapping_fpath}")
+ val_class_mapping = np.load(val_class_mapping_fpath)
+ else:
+ val_class_mapping = None
+
+ test_class_mappings = []
+ for class_mapping_fpath in test_class_mapping_fpaths:
+ if class_mapping_fpath is not None and class_mapping_fpath != "None":
+ logger.info(f"Using class mapping from {class_mapping_fpath}")
+ class_mapping = np.load(class_mapping_fpath)
+ else:
+ class_mapping = None
+ test_class_mappings.append(class_mapping)
+
+ metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
+ val_results_dict, feature_model, linear_classifiers, iteration = eval_linear(
+ feature_model=feature_model,
+ linear_classifiers=linear_classifiers,
+ train_data_loader=train_data_loader,
+ val_data_loader=val_data_loader,
+ metrics_file_path=metrics_file_path,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ output_dir=output_dir,
+ max_iter=max_iter,
+ checkpoint_period=checkpoint_period,
+ running_checkpoint_period=epoch_length,
+ eval_period=eval_period_iterations,
+ metric_type=val_metric_type,
+ training_num_classes=training_num_classes,
+ resume=resume,
+ val_class_mapping=val_class_mapping,
+ classifier_fpath=classifier_fpath,
+ )
+ results_dict = {}
+ if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
+ results_dict = test_on_datasets(
+ feature_model,
+ linear_classifiers,
+ test_dataset_strs,
+ batch_size,
+ 0, # num_workers,
+ test_metric_types,
+ metrics_file_path,
+ training_num_classes,
+ iteration,
+ val_results_dict["best_classifier"]["name"],
+ prefixstring="",
+ test_class_mappings=test_class_mappings,
+ )
+ results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"]
+ results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"]
+ logger.info("Test Results Dict " + str(results_dict))
+
+ return results_dict
+
+
+def main(args):
+ model, autocast_dtype = setup_and_build_model(args)
+ run_eval_linear(
+ model=model,
+ output_dir=args.output_dir,
+ train_dataset_str=args.train_dataset_str,
+ val_dataset_str=args.val_dataset_str,
+ test_dataset_strs=args.test_dataset_strs,
+ batch_size=args.batch_size,
+ epochs=args.epochs,
+ epoch_length=args.epoch_length,
+ num_workers=args.num_workers,
+ save_checkpoint_frequency=args.save_checkpoint_frequency,
+ eval_period_iterations=args.eval_period_iterations,
+ learning_rates=args.learning_rates,
+ autocast_dtype=autocast_dtype,
+ resume=not args.no_resume,
+ classifier_fpath=args.classifier_fpath,
+ val_metric_type=args.val_metric_type,
+ test_metric_types=args.test_metric_types,
+ val_class_mapping_fpath=args.val_class_mapping_fpath,
+ test_class_mapping_fpaths=args.test_class_mapping_fpaths,
+ )
+ return 0
+
+
+if __name__ == "__main__":
+ description = "DINOv2 linear evaluation"
+ args_parser = get_args_parser(description=description)
+ args = args_parser.parse_args()
+ sys.exit(main(args))
diff --git a/mapper/models/dinov2/eval/log_regression.py b/mapper/models/dinov2/eval/log_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f36ec134e0ce25697428a0b3f21cdc2f0145645
--- /dev/null
+++ b/mapper/models/dinov2/eval/log_regression.py
@@ -0,0 +1,444 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+import gc
+import logging
+import sys
+import time
+from typing import List, Optional
+
+from cuml.linear_model import LogisticRegression
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed
+from torch import nn
+from torch.utils.data import TensorDataset
+from torchmetrics import MetricTracker
+
+from dinov2.data import make_dataset
+from dinov2.data.transforms import make_classification_eval_transform
+from dinov2.distributed import get_global_rank, get_global_size
+from dinov2.eval.metrics import MetricType, build_metric
+from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+from dinov2.eval.setup import setup_and_build_model
+from dinov2.eval.utils import evaluate, extract_features
+from dinov2.utils.dtype import as_torch_dtype
+
+
+logger = logging.getLogger("dinov2")
+
+DEFAULT_MAX_ITER = 1_000
+C_POWER_RANGE = torch.linspace(-6, 5, 45)
+_CPU_DEVICE = torch.device("cpu")
+
+
+def get_args_parser(
+ description: Optional[str] = None,
+ parents: Optional[List[argparse.ArgumentParser]] = None,
+ add_help: bool = True,
+):
+ parents = parents or []
+ setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
+ parents = [setup_args_parser]
+ parser = argparse.ArgumentParser(
+ description=description,
+ parents=parents,
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--train-dataset",
+ dest="train_dataset_str",
+ type=str,
+ help="Training dataset",
+ )
+ parser.add_argument(
+ "--val-dataset",
+ dest="val_dataset_str",
+ type=str,
+ help="Validation dataset",
+ )
+ parser.add_argument(
+ "--finetune-dataset-str",
+ dest="finetune_dataset_str",
+ type=str,
+ help="Fine-tuning dataset",
+ )
+ parser.add_argument(
+ "--finetune-on-val",
+ action="store_true",
+ help="If there is no finetune dataset, whether to choose the "
+ "hyperparameters on the val set instead of 10%% of the train dataset",
+ )
+ parser.add_argument(
+ "--metric-type",
+ type=MetricType,
+ choices=list(MetricType),
+ help="Metric type",
+ )
+ parser.add_argument(
+ "--train-features-device",
+ type=str,
+ help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
+ )
+ parser.add_argument(
+ "--train-dtype",
+ type=str,
+ help="Data type to convert the train features to (default: %(default)s)",
+ )
+ parser.add_argument(
+ "--max-train-iters",
+ type=int,
+ help="Maximum number of train iterations (default: %(default)s)",
+ )
+ parser.set_defaults(
+ train_dataset_str="ImageNet:split=TRAIN",
+ val_dataset_str="ImageNet:split=VAL",
+ finetune_dataset_str=None,
+ metric_type=MetricType.MEAN_ACCURACY,
+ train_features_device="cpu",
+ train_dtype="float64",
+ max_train_iters=DEFAULT_MAX_ITER,
+ finetune_on_val=False,
+ )
+ return parser
+
+
+class LogRegModule(nn.Module):
+ def __init__(
+ self,
+ C,
+ max_iter=DEFAULT_MAX_ITER,
+ dtype=torch.float64,
+ device=_CPU_DEVICE,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.device = device
+ self.estimator = LogisticRegression(
+ penalty="l2",
+ C=C,
+ max_iter=max_iter,
+ output_type="numpy",
+ tol=1e-12,
+ linesearch_max_iter=50,
+ )
+
+ def forward(self, samples, targets):
+ samples_device = samples.device
+ samples = samples.to(dtype=self.dtype, device=self.device)
+ if self.device == _CPU_DEVICE:
+ samples = samples.numpy()
+ probas = self.estimator.predict_proba(samples)
+ return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}
+
+ def fit(self, train_features, train_labels):
+ train_features = train_features.to(dtype=self.dtype, device=self.device)
+ train_labels = train_labels.to(dtype=self.dtype, device=self.device)
+ if self.device == _CPU_DEVICE:
+ # both cuML and sklearn only work with numpy arrays on CPU
+ train_features = train_features.numpy()
+ train_labels = train_labels.numpy()
+ self.estimator.fit(train_features, train_labels)
+
+
+def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
+ postprocessors = {"metrics": logreg_model}
+ metrics = {"metrics": logreg_metric}
+ return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device)
+
+
+def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
+ logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
+ logreg_model.fit(train_features, train_labels)
+ return logreg_model
+
+
+def train_and_evaluate(
+ *,
+ C,
+ max_iter,
+ train_features,
+ train_labels,
+ logreg_metric,
+ test_data_loader,
+ train_dtype=torch.float64,
+ train_features_device,
+ eval_device,
+):
+ logreg_model = train_for_C(
+ C=C,
+ max_iter=max_iter,
+ train_features=train_features,
+ train_labels=train_labels,
+ dtype=train_dtype,
+ device=train_features_device,
+ )
+ return evaluate_model(
+ logreg_model=logreg_model,
+ logreg_metric=logreg_metric,
+ test_data_loader=test_data_loader,
+ device=eval_device,
+ )
+
+
+def sweep_C_values(
+ *,
+ train_features,
+ train_labels,
+ test_data_loader,
+ metric_type,
+ num_classes,
+ train_dtype=torch.float64,
+ train_features_device=_CPU_DEVICE,
+ max_train_iters=DEFAULT_MAX_ITER,
+):
+ if metric_type == MetricType.PER_CLASS_ACCURACY:
+ # If we want to output per-class accuracy, we select the hyperparameters with mean per class
+ metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
+ logreg_metric = build_metric(metric_type, num_classes=num_classes)
+ metric_tracker = MetricTracker(logreg_metric, maximize=True)
+ ALL_C = 10**C_POWER_RANGE
+ logreg_models = {}
+
+ train_features = train_features.to(dtype=train_dtype, device=train_features_device)
+ train_labels = train_labels.to(device=train_features_device)
+
+ for i in range(get_global_rank(), len(ALL_C), get_global_size()):
+ C = ALL_C[i].item()
+ logger.info(
+ f"Training for C = {C:.5f}, dtype={train_dtype}, "
+ f"features: {train_features.shape}, {train_features.dtype}, "
+ f"labels: {train_labels.shape}, {train_labels.dtype}"
+ )
+ logreg_models[C] = train_for_C(
+ C=C,
+ max_iter=max_train_iters,
+ train_features=train_features,
+ train_labels=train_labels,
+ dtype=train_dtype,
+ device=train_features_device,
+ )
+
+ gather_list = [None for _ in range(get_global_size())]
+ torch.distributed.all_gather_object(gather_list, logreg_models)
+
+ logreg_models_gathered = {}
+ for logreg_dict in gather_list:
+ logreg_models_gathered.update(logreg_dict)
+
+ for i in range(len(ALL_C)):
+ metric_tracker.increment()
+ C = ALL_C[i].item()
+ evals = evaluate_model(
+ logreg_model=logreg_models_gathered[C],
+ logreg_metric=metric_tracker,
+ test_data_loader=test_data_loader,
+ device=torch.cuda.current_device(),
+ )
+ logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}")
+
+ best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
+ best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
+ if which_epoch["top-1"] == i:
+ best_C = C
+ logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")
+
+ return best_stats, best_C
+
+
+def eval_log_regression(
+ *,
+ model,
+ train_dataset,
+ val_dataset,
+ finetune_dataset,
+ metric_type,
+ batch_size,
+ num_workers,
+ finetune_on_val=False,
+ train_dtype=torch.float64,
+ train_features_device=_CPU_DEVICE,
+ max_train_iters=DEFAULT_MAX_ITER,
+):
+ """
+ Implements the "standard" process for log regression evaluation:
+ The value of C is chosen by training on train_dataset and evaluating on
+ finetune_dataset. Then, the final model is trained on a concatenation of
+ train_dataset and finetune_dataset, and is evaluated on val_dataset.
+ If there is no finetune_dataset, the value of C is the one that yields
+ the best results on a random 10% subset of the train dataset
+ """
+
+ start = time.time()
+
+ train_features, train_labels = extract_features(
+ model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
+ )
+ val_features, val_labels = extract_features(
+ model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
+ )
+ val_data_loader = torch.utils.data.DataLoader(
+ TensorDataset(val_features, val_labels),
+ batch_size=batch_size,
+ drop_last=False,
+ num_workers=0,
+ persistent_workers=False,
+ )
+
+ if finetune_dataset is None and finetune_on_val:
+ logger.info("Choosing hyperparameters on the val dataset")
+ finetune_features, finetune_labels = val_features, val_labels
+ elif finetune_dataset is None and not finetune_on_val:
+ logger.info("Choosing hyperparameters on 10% of the train dataset")
+ torch.manual_seed(0)
+ indices = torch.randperm(len(train_features), device=train_features.device)
+ finetune_index = indices[: len(train_features) // 10]
+ train_index = indices[len(train_features) // 10 :]
+ finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
+ train_features, train_labels = train_features[train_index], train_labels[train_index]
+ else:
+ logger.info("Choosing hyperparameters on the finetune dataset")
+ finetune_features, finetune_labels = extract_features(
+ model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
+ )
+ # release the model - free GPU memory
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ finetune_data_loader = torch.utils.data.DataLoader(
+ TensorDataset(finetune_features, finetune_labels),
+ batch_size=batch_size,
+ drop_last=False,
+ )
+
+ if len(train_labels.shape) > 1:
+ num_classes = train_labels.shape[1]
+ else:
+ num_classes = train_labels.max() + 1
+
+ logger.info("Using cuML for logistic regression")
+
+ best_stats, best_C = sweep_C_values(
+ train_features=train_features,
+ train_labels=train_labels,
+ test_data_loader=finetune_data_loader,
+ metric_type=metric_type,
+ num_classes=num_classes,
+ train_dtype=train_dtype,
+ train_features_device=train_features_device,
+ max_train_iters=max_train_iters,
+ )
+
+ if not finetune_on_val:
+ logger.info("Best parameter found, concatenating features")
+ train_features = torch.cat((train_features, finetune_features))
+ train_labels = torch.cat((train_labels, finetune_labels))
+
+ logger.info("Training final model")
+ logreg_metric = build_metric(metric_type, num_classes=num_classes)
+ evals = train_and_evaluate(
+ C=best_C,
+ max_iter=max_train_iters,
+ train_features=train_features,
+ train_labels=train_labels,
+ logreg_metric=logreg_metric.clone(),
+ test_data_loader=val_data_loader,
+ eval_device=torch.cuda.current_device(),
+ train_dtype=train_dtype,
+ train_features_device=train_features_device,
+ )
+
+ best_stats = evals[1]["metrics"]
+
+ best_stats["best_C"] = best_C
+
+ logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
+ return best_stats
+
+
+def eval_log_regression_with_model(
+ model,
+ train_dataset_str="ImageNet:split=TRAIN",
+ val_dataset_str="ImageNet:split=VAL",
+ finetune_dataset_str=None,
+ autocast_dtype=torch.float,
+ finetune_on_val=False,
+ metric_type=MetricType.MEAN_ACCURACY,
+ train_dtype=torch.float64,
+ train_features_device=_CPU_DEVICE,
+ max_train_iters=DEFAULT_MAX_ITER,
+):
+ cudnn.benchmark = True
+
+ transform = make_classification_eval_transform(resize_size=224)
+ target_transform = None
+
+ train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform)
+ val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform)
+ if finetune_dataset_str is not None:
+ finetune_dataset = make_dataset(
+ dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform
+ )
+ else:
+ finetune_dataset = None
+
+ with torch.cuda.amp.autocast(dtype=autocast_dtype):
+ results_dict_logreg = eval_log_regression(
+ model=model,
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ finetune_dataset=finetune_dataset,
+ metric_type=metric_type,
+ batch_size=256,
+ num_workers=0, # 5,
+ finetune_on_val=finetune_on_val,
+ train_dtype=train_dtype,
+ train_features_device=train_features_device,
+ max_train_iters=max_train_iters,
+ )
+
+ results_dict = {
+ "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0,
+ "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0,
+ "best_C": results_dict_logreg["best_C"],
+ }
+ logger.info(
+ "\n".join(
+ [
+ "Training of the supervised logistic regression on frozen features completed.\n"
+ "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
+ "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
+ "obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
+ ]
+ )
+ )
+
+ torch.distributed.barrier()
+ return results_dict
+
+
+def main(args):
+ model, autocast_dtype = setup_and_build_model(args)
+ eval_log_regression_with_model(
+ model=model,
+ train_dataset_str=args.train_dataset_str,
+ val_dataset_str=args.val_dataset_str,
+ finetune_dataset_str=args.finetune_dataset_str,
+ autocast_dtype=autocast_dtype,
+ finetune_on_val=args.finetune_on_val,
+ metric_type=args.metric_type,
+ train_dtype=as_torch_dtype(args.train_dtype),
+ train_features_device=torch.device(args.train_features_device),
+ max_train_iters=args.max_train_iters,
+ )
+ return 0
+
+
+if __name__ == "__main__":
+ description = "DINOv2 logistic regression evaluation"
+ args_parser = get_args_parser(description=description)
+ args = args_parser.parse_args()
+ sys.exit(main(args))
diff --git a/mapper/models/dinov2/eval/metrics.py b/mapper/models/dinov2/eval/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..52be81a859dddde82da93c3657c35352d2bb0a48
--- /dev/null
+++ b/mapper/models/dinov2/eval/metrics.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import logging
+from typing import Any, Dict, Optional
+
+import torch
+from torch import Tensor
+from torchmetrics import Metric, MetricCollection
+from torchmetrics.classification import MulticlassAccuracy
+from torchmetrics.utilities.data import dim_zero_cat, select_topk
+
+
+logger = logging.getLogger("dinov2")
+
+
+class MetricType(Enum):
+ MEAN_ACCURACY = "mean_accuracy"
+ MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
+ PER_CLASS_ACCURACY = "per_class_accuracy"
+ IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"
+
+ @property
+ def accuracy_averaging(self):
+ return getattr(AccuracyAveraging, self.name, None)
+
+ def __str__(self):
+ return self.value
+
+
+class AccuracyAveraging(Enum):
+ MEAN_ACCURACY = "micro"
+ MEAN_PER_CLASS_ACCURACY = "macro"
+ PER_CLASS_ACCURACY = "none"
+
+ def __str__(self):
+ return self.value
+
+
+def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
+ if metric_type.accuracy_averaging is not None:
+ return build_topk_accuracy_metric(
+ average_type=metric_type.accuracy_averaging,
+ num_classes=num_classes,
+ ks=(1, 5) if ks is None else ks,
+ )
+ elif metric_type == MetricType.IMAGENET_REAL_ACCURACY:
+ return build_topk_imagenet_real_accuracy_metric(
+ num_classes=num_classes,
+ ks=(1, 5) if ks is None else ks,
+ )
+
+ raise ValueError(f"Unknown metric type {metric_type}")
+
+
+def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
+ metrics: Dict[str, Metric] = {
+ f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks
+ }
+ return MetricCollection(metrics)
+
+
+def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
+ metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks}
+ return MetricCollection(metrics)
+
+
+class ImageNetReaLAccuracy(Metric):
+ is_differentiable: bool = False
+ higher_is_better: Optional[bool] = None
+ full_state_update: bool = False
+
+ def __init__(
+ self,
+ num_classes: int,
+ top_k: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.num_classes = num_classes
+ self.top_k = top_k
+ self.add_state("tp", [], dist_reduce_fx="cat")
+
+ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
+ # preds [B, D]
+ # target [B, A]
+ # preds_oh [B, D] with 0 and 1
+ # select top K highest probabilities, use one hot representation
+ preds_oh = select_topk(preds, self.top_k)
+ # target_oh [B, D + 1] with 0 and 1
+ target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
+ target = target.long()
+ # for undefined targets (-1) use a fake value `num_classes`
+ target[target == -1] = self.num_classes
+ # fill targets, use one hot representation
+ target_oh.scatter_(1, target, 1)
+ # target_oh [B, D] (remove the fake target at index `num_classes`)
+ target_oh = target_oh[:, :-1]
+ # tp [B] with 0 and 1
+ tp = (preds_oh * target_oh == 1).sum(dim=1)
+ # at least one match between prediction and target
+ tp.clip_(max=1)
+ # ignore instances where no targets are defined
+ mask = target_oh.sum(dim=1) > 0
+ tp = tp[mask]
+ self.tp.append(tp) # type: ignore
+
+ def compute(self) -> Tensor:
+ tp = dim_zero_cat(self.tp) # type: ignore
+ return tp.float().mean()
diff --git a/mapper/models/dinov2/eval/segmentation/__init__.py b/mapper/models/dinov2/eval/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/eval/segmentation/hooks/__init__.py b/mapper/models/dinov2/eval/segmentation/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..738cc2d2069521ea0353acd0cb0a03e3ddf1fa51
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/hooks/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .optimizer import DistOptimizerHook
diff --git a/mapper/models/dinov2/eval/segmentation/hooks/optimizer.py b/mapper/models/dinov2/eval/segmentation/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f593f26a84475bbf7ebda9607a4d10914b13a443
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/hooks/optimizer.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+try:
+ import apex
+except ImportError:
+ print("apex is not installed")
+
+from mmcv.runner import OptimizerHook, HOOKS
+
+
+@HOOKS.register_module()
+class DistOptimizerHook(OptimizerHook):
+ """Optimizer hook for distributed training."""
+
+ def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.update_interval = update_interval
+ self.use_fp16 = use_fp16
+
+ def before_run(self, runner):
+ runner.optimizer.zero_grad()
+
+ def after_train_iter(self, runner):
+ runner.outputs["loss"] /= self.update_interval
+ if self.use_fp16:
+ # runner.outputs['loss'].backward()
+ with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ runner.outputs["loss"].backward()
+ if self.every_n_iters(runner, self.update_interval):
+ if self.grad_clip is not None:
+ self.clip_grads(runner.model.parameters())
+ runner.optimizer.step()
+ runner.optimizer.zero_grad()
diff --git a/mapper/models/dinov2/eval/segmentation/models/__init__.py b/mapper/models/dinov2/eval/segmentation/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e4563d4c162d67e7900955a06bd9248d4c9a48
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/models/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .backbones import * # noqa: F403
+from .decode_heads import * # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation/models/backbones/__init__.py b/mapper/models/dinov2/eval/segmentation/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/models/backbones/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .vision_transformer import DinoVisionTransformer
diff --git a/mapper/models/dinov2/eval/segmentation/models/backbones/vision_transformer.py b/mapper/models/dinov2/eval/segmentation/models/backbones/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e9753ae92a36be52f100e3004cbeeff777d14a
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/models/backbones/vision_transformer.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mmcv.runner import BaseModule
+from mmseg.models.builder import BACKBONES
+
+
+@BACKBONES.register_module()
+class DinoVisionTransformer(BaseModule):
+ """Vision Transformer."""
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
diff --git a/mapper/models/dinov2/eval/segmentation/models/decode_heads/__init__.py b/mapper/models/dinov2/eval/segmentation/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55317875262dadf8970c2b3882f016b8d4731ac
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/models/decode_heads/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .linear_head import BNHead
diff --git a/mapper/models/dinov2/eval/segmentation/models/decode_heads/linear_head.py b/mapper/models/dinov2/eval/segmentation/models/decode_heads/linear_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f39c68fb136f84d1aa5284da5b69581bb177cc
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/models/decode_heads/linear_head.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from mmseg.models.builder import HEADS
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.ops import resize
+
+
+@HEADS.register_module()
+class BNHead(BaseDecodeHead):
+ """Just a batchnorm."""
+
+ def __init__(self, resize_factors=None, **kwargs):
+ super().__init__(**kwargs)
+ assert self.in_channels == self.channels
+ self.bn = nn.SyncBatchNorm(self.in_channels)
+ self.resize_factors = resize_factors
+
+ def _forward_feature(self, inputs):
+ """Forward function for feature maps before classifying each pixel with
+ ``self.cls_seg`` fc.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
+ H, W) which is feature map for last layer of decoder head.
+ """
+ # print("inputs", [i.shape for i in inputs])
+ x = self._transform_inputs(inputs)
+ # print("x", x.shape)
+ feats = self.bn(x)
+ # print("feats", feats.shape)
+ return feats
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == "resize_concat":
+ # accept lists (for cls token)
+ input_list = []
+ for x in inputs:
+ if isinstance(x, list):
+ input_list.extend(x)
+ else:
+ input_list.append(x)
+ inputs = input_list
+ # an image descriptor can be a local descriptor with resolution 1x1
+ for i, x in enumerate(inputs):
+ if len(x.shape) == 2:
+ inputs[i] = x[:, :, None, None]
+ # select indices
+ inputs = [inputs[i] for i in self.in_index]
+ # Resizing shenanigans
+ # print("before", *(x.shape for x in inputs))
+ if self.resize_factors is not None:
+ assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs))
+ inputs = [
+ resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
+ for x, f in zip(inputs, self.resize_factors)
+ ]
+ # print("after", *(x.shape for x in inputs))
+ upsampled_inputs = [
+ resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
+ for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == "multiple_select":
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def forward(self, inputs):
+ """Forward function."""
+ output = self._forward_feature(inputs)
+ output = self.cls_seg(output)
+ return output
diff --git a/mapper/models/dinov2/eval/segmentation/utils/__init__.py b/mapper/models/dinov2/eval/segmentation/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/eval/segmentation/utils/colormaps.py b/mapper/models/dinov2/eval/segmentation/utils/colormaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ef604b2c75792e95e438abfd51ab03d40de340
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation/utils/colormaps.py
@@ -0,0 +1,362 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+ADE20K_COLORMAP = [
+ (0, 0, 0),
+ (120, 120, 120),
+ (180, 120, 120),
+ (6, 230, 230),
+ (80, 50, 50),
+ (4, 200, 3),
+ (120, 120, 80),
+ (140, 140, 140),
+ (204, 5, 255),
+ (230, 230, 230),
+ (4, 250, 7),
+ (224, 5, 255),
+ (235, 255, 7),
+ (150, 5, 61),
+ (120, 120, 70),
+ (8, 255, 51),
+ (255, 6, 82),
+ (143, 255, 140),
+ (204, 255, 4),
+ (255, 51, 7),
+ (204, 70, 3),
+ (0, 102, 200),
+ (61, 230, 250),
+ (255, 6, 51),
+ (11, 102, 255),
+ (255, 7, 71),
+ (255, 9, 224),
+ (9, 7, 230),
+ (220, 220, 220),
+ (255, 9, 92),
+ (112, 9, 255),
+ (8, 255, 214),
+ (7, 255, 224),
+ (255, 184, 6),
+ (10, 255, 71),
+ (255, 41, 10),
+ (7, 255, 255),
+ (224, 255, 8),
+ (102, 8, 255),
+ (255, 61, 6),
+ (255, 194, 7),
+ (255, 122, 8),
+ (0, 255, 20),
+ (255, 8, 41),
+ (255, 5, 153),
+ (6, 51, 255),
+ (235, 12, 255),
+ (160, 150, 20),
+ (0, 163, 255),
+ (140, 140, 140),
+ (250, 10, 15),
+ (20, 255, 0),
+ (31, 255, 0),
+ (255, 31, 0),
+ (255, 224, 0),
+ (153, 255, 0),
+ (0, 0, 255),
+ (255, 71, 0),
+ (0, 235, 255),
+ (0, 173, 255),
+ (31, 0, 255),
+ (11, 200, 200),
+ (255, 82, 0),
+ (0, 255, 245),
+ (0, 61, 255),
+ (0, 255, 112),
+ (0, 255, 133),
+ (255, 0, 0),
+ (255, 163, 0),
+ (255, 102, 0),
+ (194, 255, 0),
+ (0, 143, 255),
+ (51, 255, 0),
+ (0, 82, 255),
+ (0, 255, 41),
+ (0, 255, 173),
+ (10, 0, 255),
+ (173, 255, 0),
+ (0, 255, 153),
+ (255, 92, 0),
+ (255, 0, 255),
+ (255, 0, 245),
+ (255, 0, 102),
+ (255, 173, 0),
+ (255, 0, 20),
+ (255, 184, 184),
+ (0, 31, 255),
+ (0, 255, 61),
+ (0, 71, 255),
+ (255, 0, 204),
+ (0, 255, 194),
+ (0, 255, 82),
+ (0, 10, 255),
+ (0, 112, 255),
+ (51, 0, 255),
+ (0, 194, 255),
+ (0, 122, 255),
+ (0, 255, 163),
+ (255, 153, 0),
+ (0, 255, 10),
+ (255, 112, 0),
+ (143, 255, 0),
+ (82, 0, 255),
+ (163, 255, 0),
+ (255, 235, 0),
+ (8, 184, 170),
+ (133, 0, 255),
+ (0, 255, 92),
+ (184, 0, 255),
+ (255, 0, 31),
+ (0, 184, 255),
+ (0, 214, 255),
+ (255, 0, 112),
+ (92, 255, 0),
+ (0, 224, 255),
+ (112, 224, 255),
+ (70, 184, 160),
+ (163, 0, 255),
+ (153, 0, 255),
+ (71, 255, 0),
+ (255, 0, 163),
+ (255, 204, 0),
+ (255, 0, 143),
+ (0, 255, 235),
+ (133, 255, 0),
+ (255, 0, 235),
+ (245, 0, 255),
+ (255, 0, 122),
+ (255, 245, 0),
+ (10, 190, 212),
+ (214, 255, 0),
+ (0, 204, 255),
+ (20, 0, 255),
+ (255, 255, 0),
+ (0, 153, 255),
+ (0, 41, 255),
+ (0, 255, 204),
+ (41, 0, 255),
+ (41, 255, 0),
+ (173, 0, 255),
+ (0, 245, 255),
+ (71, 0, 255),
+ (122, 0, 255),
+ (0, 255, 184),
+ (0, 92, 255),
+ (184, 255, 0),
+ (0, 133, 255),
+ (255, 214, 0),
+ (25, 194, 194),
+ (102, 255, 0),
+ (92, 0, 255),
+]
+
+ADE20K_CLASS_NAMES = [
+ "",
+ "wall",
+ "building;edifice",
+ "sky",
+ "floor;flooring",
+ "tree",
+ "ceiling",
+ "road;route",
+ "bed",
+ "windowpane;window",
+ "grass",
+ "cabinet",
+ "sidewalk;pavement",
+ "person;individual;someone;somebody;mortal;soul",
+ "earth;ground",
+ "door;double;door",
+ "table",
+ "mountain;mount",
+ "plant;flora;plant;life",
+ "curtain;drape;drapery;mantle;pall",
+ "chair",
+ "car;auto;automobile;machine;motorcar",
+ "water",
+ "painting;picture",
+ "sofa;couch;lounge",
+ "shelf",
+ "house",
+ "sea",
+ "mirror",
+ "rug;carpet;carpeting",
+ "field",
+ "armchair",
+ "seat",
+ "fence;fencing",
+ "desk",
+ "rock;stone",
+ "wardrobe;closet;press",
+ "lamp",
+ "bathtub;bathing;tub;bath;tub",
+ "railing;rail",
+ "cushion",
+ "base;pedestal;stand",
+ "box",
+ "column;pillar",
+ "signboard;sign",
+ "chest;of;drawers;chest;bureau;dresser",
+ "counter",
+ "sand",
+ "sink",
+ "skyscraper",
+ "fireplace;hearth;open;fireplace",
+ "refrigerator;icebox",
+ "grandstand;covered;stand",
+ "path",
+ "stairs;steps",
+ "runway",
+ "case;display;case;showcase;vitrine",
+ "pool;table;billiard;table;snooker;table",
+ "pillow",
+ "screen;door;screen",
+ "stairway;staircase",
+ "river",
+ "bridge;span",
+ "bookcase",
+ "blind;screen",
+ "coffee;table;cocktail;table",
+ "toilet;can;commode;crapper;pot;potty;stool;throne",
+ "flower",
+ "book",
+ "hill",
+ "bench",
+ "countertop",
+ "stove;kitchen;stove;range;kitchen;range;cooking;stove",
+ "palm;palm;tree",
+ "kitchen;island",
+ "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
+ "swivel;chair",
+ "boat",
+ "bar",
+ "arcade;machine",
+ "hovel;hut;hutch;shack;shanty",
+ "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
+ "towel",
+ "light;light;source",
+ "truck;motortruck",
+ "tower",
+ "chandelier;pendant;pendent",
+ "awning;sunshade;sunblind",
+ "streetlight;street;lamp",
+ "booth;cubicle;stall;kiosk",
+ "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
+ "airplane;aeroplane;plane",
+ "dirt;track",
+ "apparel;wearing;apparel;dress;clothes",
+ "pole",
+ "land;ground;soil",
+ "bannister;banister;balustrade;balusters;handrail",
+ "escalator;moving;staircase;moving;stairway",
+ "ottoman;pouf;pouffe;puff;hassock",
+ "bottle",
+ "buffet;counter;sideboard",
+ "poster;posting;placard;notice;bill;card",
+ "stage",
+ "van",
+ "ship",
+ "fountain",
+ "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
+ "canopy",
+ "washer;automatic;washer;washing;machine",
+ "plaything;toy",
+ "swimming;pool;swimming;bath;natatorium",
+ "stool",
+ "barrel;cask",
+ "basket;handbasket",
+ "waterfall;falls",
+ "tent;collapsible;shelter",
+ "bag",
+ "minibike;motorbike",
+ "cradle",
+ "oven",
+ "ball",
+ "food;solid;food",
+ "step;stair",
+ "tank;storage;tank",
+ "trade;name;brand;name;brand;marque",
+ "microwave;microwave;oven",
+ "pot;flowerpot",
+ "animal;animate;being;beast;brute;creature;fauna",
+ "bicycle;bike;wheel;cycle",
+ "lake",
+ "dishwasher;dish;washer;dishwashing;machine",
+ "screen;silver;screen;projection;screen",
+ "blanket;cover",
+ "sculpture",
+ "hood;exhaust;hood",
+ "sconce",
+ "vase",
+ "traffic;light;traffic;signal;stoplight",
+ "tray",
+ "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
+ "fan",
+ "pier;wharf;wharfage;dock",
+ "crt;screen",
+ "plate",
+ "monitor;monitoring;device",
+ "bulletin;board;notice;board",
+ "shower",
+ "radiator",
+ "glass;drinking;glass",
+ "clock",
+ "flag",
+]
+
+
+VOC2012_COLORMAP = [
+ (0, 0, 0),
+ (128, 0, 0),
+ (0, 128, 0),
+ (128, 128, 0),
+ (0, 0, 128),
+ (128, 0, 128),
+ (0, 128, 128),
+ (128, 128, 128),
+ (64, 0, 0),
+ (192, 0, 0),
+ (64, 128, 0),
+ (192, 128, 0),
+ (64, 0, 128),
+ (192, 0, 128),
+ (64, 128, 128),
+ (192, 128, 128),
+ (0, 64, 0),
+ (128, 64, 0),
+ (0, 192, 0),
+ (128, 192, 0),
+ (0, 64, 128),
+]
+
+
+VOC2012_CLASS_NAMES = [
+ "",
+ "aeroplane",
+ "bicycle",
+ "bird",
+ "boat",
+ "bottle",
+ "bus",
+ "car",
+ "cat",
+ "chair",
+ "cow",
+ "diningtable",
+ "dog",
+ "horse",
+ "motorbike",
+ "person",
+ "pottedplant",
+ "sheep",
+ "sofa",
+ "train",
+ "tvmonitor",
+]
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c678fdf8f1dee14d7cf9be70af14e6f9a1441c3
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .core import * # noqa: F403
+from .models import * # noqa: F403
+from .ops import * # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92599806fbd221c1418d179892a0f46dc0b7d4db
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mmseg.core.evaluation import * # noqa: F403
+from mmseg.core.seg import * # noqa: F403
+
+from .anchor import * # noqa: F403
+from .box import * # noqa: F403
+from .utils import * # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71ac4d6e01462221ae01aa16d0e1231cda7e2e7
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .point_generator import MlvlPointGenerator # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dba90e22de76d2f23a86d3c057f196d55a99690
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/builder.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+from mmcv.utils import Registry, build_from_cfg
+
+PRIOR_GENERATORS = Registry("Generator for anchors and points")
+
+ANCHOR_GENERATORS = PRIOR_GENERATORS
+
+
+def build_prior_generator(cfg, default_args=None):
+ return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)
+
+
+def build_anchor_generator(cfg, default_args=None):
+ warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ")
+ return build_prior_generator(cfg, default_args=default_args)
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..574d71939080e22284fe99087fb2e7336657bd97
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+from .builder import PRIOR_GENERATORS
+
+
+@PRIOR_GENERATORS.register_module()
+class MlvlPointGenerator:
+ """Standard points generator for multi-level (Mlvl) feature maps in 2D
+ points-based detectors.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ offset (float): The offset of points, the value is normalized with
+ corresponding stride. Defaults to 0.5.
+ """
+
+ def __init__(self, strides, offset=0.5):
+ self.strides = [_pair(stride) for stride in strides]
+ self.offset = offset
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.strides)
+
+ @property
+ def num_base_priors(self):
+ """list[int]: The number of priors (points) at a point
+ on the feature grid"""
+ return [1 for _ in range(len(self.strides))]
+
+ def _meshgrid(self, x, y, row_major=True):
+ yy, xx = torch.meshgrid(y, x)
+ if row_major:
+ # warning .flatten() would cause error in ONNX exporting
+ # have to use reshape here
+ return xx.reshape(-1), yy.reshape(-1)
+
+ else:
+ return yy.reshape(-1), xx.reshape(-1)
+
+ def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
+ """Generate grid points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels, each size arrange as
+ as (h, w).
+ dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
+ device (str): The device where the anchors will be put on.
+ with_stride (bool): Whether to concatenate the stride to
+ the last dimension of points.
+
+ Return:
+ list[torch.Tensor]: Points of multiple feature levels.
+ The sizes of each tensor should be (N, 2) when with stride is
+ ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represent (coord_x, coord_y),
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represent
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_priors = []
+ for i in range(self.num_levels):
+ priors = self.single_level_grid_priors(
+ featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride
+ )
+ multi_level_priors.append(priors)
+ return multi_level_priors
+
+ def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
+ """Generate grid Points of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_priors``.
+
+ Args:
+ featmap_size (tuple[int]): Size of the feature maps, arrange as
+ (h, w).
+ level_idx (int): The index of corresponding feature map level.
+ dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
+ device (str, optional): The device the tensor will be put on.
+ Defaults to 'cuda'.
+ with_stride (bool): Concatenate the stride to the last dimension
+ of points.
+
+ Return:
+ Tensor: Points of single feature levels.
+ The shape of tensor should be (N, 2) when with stride is
+ ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represent (coord_x, coord_y),
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represent
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+ feat_h, feat_w = featmap_size
+ stride_w, stride_h = self.strides[level_idx]
+ shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_x = shift_x.to(dtype)
+
+ shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_y = shift_y.to(dtype)
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ if not with_stride:
+ shifts = torch.stack([shift_xx, shift_yy], dim=-1)
+ else:
+ # use `shape[0]` instead of `len(shift_xx)` for ONNX export
+ stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype)
+ stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype)
+ shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
+ all_points = shifts.to(device)
+ return all_points
+
+ def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
+ """Generate valid flags of points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in
+ multiple feature levels, each size arrange as
+ as (h, w).
+ pad_shape (tuple(int)): The padded shape of the image,
+ arrange as (h, w).
+ device (str): The device where the anchors will be put on.
+
+ Return:
+ list(torch.Tensor): Valid flags of points of multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_flags = []
+ for i in range(self.num_levels):
+ point_stride = self.strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = pad_shape[:2]
+ valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
+ valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
+ flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device)
+ multi_level_flags.append(flags)
+ return multi_level_flags
+
+ def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
+ """Generate the valid flags of points of a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps, arrange as
+ as (h, w).
+ valid_size (tuple[int]): The valid size of the feature maps.
+ The size arrange as as (h, w).
+ device (str, optional): The device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each points in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ return valid
+
+ def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
+ """Generate sparse points according to the ``prior_idxs``.
+
+ Args:
+ prior_idxs (Tensor): The index of corresponding anchors
+ in the feature map.
+ featmap_size (tuple[int]): feature map size arrange as (w, h).
+ level_idx (int): The level index of corresponding feature
+ map.
+ dtype (obj:`torch.dtype`): Date type of points. Defaults to
+ ``torch.float32``.
+ device (obj:`torch.device`): The device where the points is
+ located.
+ Returns:
+ Tensor: Anchor with shape (N, 2), N should be equal to
+ the length of ``prior_idxs``. And last dimension
+ 2 represent (coord_x, coord_y).
+ """
+ height, width = featmap_size
+ x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
+ y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1]
+ prioris = torch.stack([x, y], 1).to(dtype)
+ prioris = prioris.to(device)
+ return prioris
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf35a613f81acd77ecab2dfb75a722fa8e5c0787
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .builder import * # noqa: F403
+from .samplers import MaskPseudoSampler # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/builder.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9538c0de3db682c2b111b085a8a1ce321c76a9ff
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/builder.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_SAMPLERS = Registry("bbox_sampler")
+BBOX_CODERS = Registry("bbox_coder")
+
+
+def build_sampler(cfg, **default_args):
+ """Builder of box sampler."""
+ return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
+
+
+def build_bbox_coder(cfg, **default_args):
+ """Builder of box coder."""
+ return build_from_cfg(cfg, BBOX_CODERS, default_args)
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c363e3fabc365d92aeaf1e78189d710db279e9
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45cec3ed7af5b49bb54b92d6e6bcf59b06b4c99
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+ """Base class of samplers."""
+
+ def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
+ self.num = num
+ self.pos_fraction = pos_fraction
+ self.neg_pos_ub = neg_pos_ub
+ self.add_gt_as_proposals = add_gt_as_proposals
+ self.pos_sampler = self
+ self.neg_sampler = self
+
+ @abstractmethod
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive samples."""
+ pass
+
+ @abstractmethod
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative samples."""
+ pass
+
+ def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+ :obj:`SamplingResult`: Sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox import RandomSampler
+ >>> from mmdet.core.bbox import AssignResult
+ >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+ >>> rng = ensure_rng(None)
+ >>> assign_result = AssignResult.random(rng=rng)
+ >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
+ >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
+ >>> gt_labels = None
+ >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+ >>> add_gt_as_proposals=False)
+ >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ """
+ if len(bboxes.shape) < 2:
+ bboxes = bboxes[None, :]
+
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
+ if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+ if gt_labels is None:
+ raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
+ # We found that sampled indices have duplicated items occasionally.
+ # (may be a bug of PyTorch)
+ pos_inds = pos_inds.unique()
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
+ neg_inds = neg_inds.unique()
+
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
+ return sampling_result
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e67ea61ed0fd65cca0addde1893a3c1e176bf15
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
+
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+from .mask_sampling_result import MaskSamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class MaskPseudoSampler(BaseSampler):
+ """A pseudo sampler that does not do sampling actually."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
+
+ def sample(self, assign_result, masks, gt_masks, **kwargs):
+ """Directly returns the positive and negative indices of samples.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ masks (torch.Tensor): Bounding boxes
+ gt_masks (torch.Tensor): Ground truth boxes
+ Returns:
+ :obj:`SamplingResult`: sampler results
+ """
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+ gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
+ sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)
+ return sampling_result
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..270ffd35a5f120dd0560a7fea7fe83ef0bab66bb
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class MaskSamplingResult(SamplingResult):
+ """Mask sampling result."""
+
+ def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_masks = masks[pos_inds]
+ self.neg_masks = masks[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_masks.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_masks.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_masks = torch.empty_like(gt_masks)
+ else:
+ self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def masks(self):
+ """torch.Tensor: concatenated positive and negative boxes"""
+ return torch.cat([self.pos_masks, self.neg_masks])
+
+ def __nice__(self):
+ data = self.info.copy()
+ data["pos_masks"] = data.pop("pos_masks").shape
+ data["neg_masks"] = data.pop("neg_masks").shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = " " + ",\n ".join(parts)
+ return "{\n" + body + "\n}"
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ "pos_inds": self.pos_inds,
+ "neg_inds": self.neg_inds,
+ "pos_masks": self.pos_masks,
+ "neg_masks": self.neg_masks,
+ "pos_is_gt": self.pos_is_gt,
+ "num_gts": self.num_gts,
+ "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
+ }
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaee3fe55aeb8c6da7edefbbd382d94b67b6a6b4
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class SamplingResult:
+ """Bbox sampling result.
+
+ Example:
+ >>> # xdoctest: +IGNORE_WANT
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random(rng=10)
+ >>> print(f'self = {self}')
+ self =
+ """
+
+ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_bboxes = bboxes[pos_inds]
+ self.neg_bboxes = bboxes[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_bboxes.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_bboxes.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+ else:
+ if len(gt_bboxes.shape) < 2:
+ gt_bboxes = gt_bboxes.view(-1, 4)
+
+ self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def bboxes(self):
+ """torch.Tensor: concatenated positive and negative boxes"""
+ return torch.cat([self.pos_bboxes, self.neg_bboxes])
+
+ def to(self, device):
+ """Change the device of the data inplace.
+
+ Example:
+ >>> self = SamplingResult.random()
+ >>> print(f'self = {self.to(None)}')
+ >>> # xdoctest: +REQUIRES(--gpu)
+ >>> print(f'self = {self.to(0)}')
+ """
+ _dict = self.__dict__
+ for key, value in _dict.items():
+ if isinstance(value, torch.Tensor):
+ _dict[key] = value.to(device)
+ return self
+
+ def __nice__(self):
+ data = self.info.copy()
+ data["pos_bboxes"] = data.pop("pos_bboxes").shape
+ data["neg_bboxes"] = data.pop("neg_bboxes").shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = " " + ",\n ".join(parts)
+ return "{\n" + body + "\n}"
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ "pos_inds": self.pos_inds,
+ "neg_inds": self.neg_inds,
+ "pos_bboxes": self.pos_bboxes,
+ "neg_bboxes": self.neg_bboxes,
+ "pos_is_gt": self.pos_is_gt,
+ "num_gts": self.num_gts,
+ "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
+ }
+
+ @classmethod
+ def random(cls, rng=None, **kwargs):
+ """
+ Args:
+ rng (None | int | numpy.random.RandomState): seed or state.
+ kwargs (keyword arguments):
+ - num_preds: number of predicted boxes
+ - num_gts: number of true boxes
+ - p_ignore (float): probability of a predicted box assigned to \
+ an ignored truth.
+ - p_assigned (float): probability of a predicted box not being \
+ assigned.
+ - p_use_label (float | bool): with labels or not.
+
+ Returns:
+ :obj:`SamplingResult`: Randomly generated sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random()
+ >>> print(self.__dict__)
+ """
+ from mmdet.core.bbox import demodata
+ from mmdet.core.bbox.assigners.assign_result import AssignResult
+ from mmdet.core.bbox.samplers.random_sampler import RandomSampler
+
+ rng = demodata.ensure_rng(rng)
+
+ # make probabalistic?
+ num = 32
+ pos_fraction = 0.5
+ neg_pos_ub = -1
+
+ assign_result = AssignResult.random(rng=rng, **kwargs)
+
+ # Note we could just compute an assignment
+ bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+ gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+ if rng.rand() > 0.2:
+ # sometimes algorithms squeeze their data, be robust to that
+ gt_bboxes = gt_bboxes.squeeze()
+ bboxes = bboxes.squeeze()
+
+ if assign_result.labels is None:
+ gt_labels = None
+ else:
+ gt_labels = None
+
+ if gt_labels is None:
+ add_gt_as_proposals = False
+ else:
+ add_gt_as_proposals = True # make probabalistic?
+
+ sampler = RandomSampler(
+ num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
+ )
+ self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ return self
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cdc9e19352f50bc2d5433c412ff71186c5df019
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dist_utils import reduce_mean
+from .misc import add_prefix, multi_apply
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dfed42da821cd94e31b663d86b20b8f09799b30
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch.distributed as dist
+
+
+def reduce_mean(tensor):
+ """ "Obtain the mean of tensor on different GPUs."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/core/utils/misc.py b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e07579e7b182b62153e81fe637ffd0f3081ef2a3
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/core/utils/misc.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+
+def multi_apply(func, *args, **kwargs):
+ """Apply function to a list of arguments.
+
+ Note:
+ This function applies the ``func`` to multiple inputs and
+ map the multiple outputs of the ``func`` into different
+ list. Each list contains the same type of outputs corresponding
+ to different inputs.
+
+ Args:
+ func (Function): A function that will be applied to a list of
+ arguments
+
+ Returns:
+ tuple(list): A tuple containing multiple list, each list contains \
+ a kind of returned results by the function
+ """
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+ Returns:
+
+ dict: The dict with keys updated with ``prefix``.
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f"{prefix}.{name}"] = value
+
+ return outputs
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed89bb0064d82b4360af020798eab3d2f5a47937
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .backbones import * # noqa: F403
+from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
+from .decode_heads import * # noqa: F403
+from .losses import * # noqa: F403
+from .plugins import * # noqa: F403
+from .segmentors import * # noqa: F403
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4bf73bcbcee710676f81cb6517ae787f4d61cc6
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .vit_adapter import ViTAdapter
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..26bfdf8f6ae6c107d22d61985cce34d4b5ce275f
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py
@@ -0,0 +1,442 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from ...ops.modules import MSDeformAttn
+from .drop_path import DropPath
+
+
+def get_reference_points(spatial_shapes, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / H_
+ ref_x = ref_x.reshape(-1)[None] / W_
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None]
+ return reference_points
+
+
+def deform_inputs(x, patch_size):
+ bs, c, h, w = x.shape
+ spatial_shapes = torch.as_tensor(
+ [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
+ )
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
+ deform_inputs1 = [reference_points, spatial_shapes, level_start_index]
+
+ spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
+ deform_inputs2 = [reference_points, spatial_shapes, level_start_index]
+
+ return deform_inputs1, deform_inputs2
+
+
+class ConvFFN(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.dwconv(x, H, W)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ n = N // 21
+ x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
+ x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
+ x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
+ x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
+ x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
+ x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
+ x = torch.cat([x1, x2, x3], dim=1)
+ return x
+
+
+class Extractor(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=6,
+ n_points=4,
+ n_levels=1,
+ deform_ratio=1.0,
+ with_cffn=True,
+ cffn_ratio=0.25,
+ drop=0.0,
+ drop_path=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ with_cp=False,
+ ):
+ super().__init__()
+ self.query_norm = norm_layer(dim)
+ self.feat_norm = norm_layer(dim)
+ self.attn = MSDeformAttn(
+ d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
+ )
+ self.with_cffn = with_cffn
+ self.with_cp = with_cp
+ if with_cffn:
+ self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
+ self.ffn_norm = norm_layer(dim)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
+ def _inner_forward(query, feat):
+
+ attn = self.attn(
+ self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
+ )
+ query = query + attn
+
+ if self.with_cffn:
+ query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
+ return query
+
+ if self.with_cp and query.requires_grad:
+ query = cp.checkpoint(_inner_forward, query, feat)
+ else:
+ query = _inner_forward(query, feat)
+
+ return query
+
+
+class Injector(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=6,
+ n_points=4,
+ n_levels=1,
+ deform_ratio=1.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=0.0,
+ with_cp=False,
+ ):
+ super().__init__()
+ self.with_cp = with_cp
+ self.query_norm = norm_layer(dim)
+ self.feat_norm = norm_layer(dim)
+ self.attn = MSDeformAttn(
+ d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
+ )
+ self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
+ def _inner_forward(query, feat):
+
+ attn = self.attn(
+ self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
+ )
+ return query + self.gamma * attn
+
+ if self.with_cp and query.requires_grad:
+ query = cp.checkpoint(_inner_forward, query, feat)
+ else:
+ query = _inner_forward(query, feat)
+
+ return query
+
+
+class InteractionBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=6,
+ n_points=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ drop=0.0,
+ drop_path=0.0,
+ with_cffn=True,
+ cffn_ratio=0.25,
+ init_values=0.0,
+ deform_ratio=1.0,
+ extra_extractor=False,
+ with_cp=False,
+ ):
+ super().__init__()
+
+ self.injector = Injector(
+ dim=dim,
+ n_levels=3,
+ num_heads=num_heads,
+ init_values=init_values,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ deform_ratio=deform_ratio,
+ with_cp=with_cp,
+ )
+ self.extractor = Extractor(
+ dim=dim,
+ n_levels=1,
+ num_heads=num_heads,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ deform_ratio=deform_ratio,
+ with_cffn=with_cffn,
+ cffn_ratio=cffn_ratio,
+ drop=drop,
+ drop_path=drop_path,
+ with_cp=with_cp,
+ )
+ if extra_extractor:
+ self.extra_extractors = nn.Sequential(
+ *[
+ Extractor(
+ dim=dim,
+ num_heads=num_heads,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ with_cffn=with_cffn,
+ cffn_ratio=cffn_ratio,
+ deform_ratio=deform_ratio,
+ drop=drop,
+ drop_path=drop_path,
+ with_cp=with_cp,
+ )
+ for _ in range(2)
+ ]
+ )
+ else:
+ self.extra_extractors = None
+
+ def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
+ x = self.injector(
+ query=x,
+ reference_points=deform_inputs1[0],
+ feat=c,
+ spatial_shapes=deform_inputs1[1],
+ level_start_index=deform_inputs1[2],
+ )
+ for idx, blk in enumerate(blocks):
+ x = blk(x, H_toks, W_toks)
+ c = self.extractor(
+ query=c,
+ reference_points=deform_inputs2[0],
+ feat=x,
+ spatial_shapes=deform_inputs2[1],
+ level_start_index=deform_inputs2[2],
+ H=H_c,
+ W=W_c,
+ )
+ if self.extra_extractors is not None:
+ for extractor in self.extra_extractors:
+ c = extractor(
+ query=c,
+ reference_points=deform_inputs2[0],
+ feat=x,
+ spatial_shapes=deform_inputs2[1],
+ level_start_index=deform_inputs2[2],
+ H=H_c,
+ W=W_c,
+ )
+ return x, c
+
+
+class InteractionBlockWithCls(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=6,
+ n_points=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ drop=0.0,
+ drop_path=0.0,
+ with_cffn=True,
+ cffn_ratio=0.25,
+ init_values=0.0,
+ deform_ratio=1.0,
+ extra_extractor=False,
+ with_cp=False,
+ ):
+ super().__init__()
+
+ self.injector = Injector(
+ dim=dim,
+ n_levels=3,
+ num_heads=num_heads,
+ init_values=init_values,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ deform_ratio=deform_ratio,
+ with_cp=with_cp,
+ )
+ self.extractor = Extractor(
+ dim=dim,
+ n_levels=1,
+ num_heads=num_heads,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ deform_ratio=deform_ratio,
+ with_cffn=with_cffn,
+ cffn_ratio=cffn_ratio,
+ drop=drop,
+ drop_path=drop_path,
+ with_cp=with_cp,
+ )
+ if extra_extractor:
+ self.extra_extractors = nn.Sequential(
+ *[
+ Extractor(
+ dim=dim,
+ num_heads=num_heads,
+ n_points=n_points,
+ norm_layer=norm_layer,
+ with_cffn=with_cffn,
+ cffn_ratio=cffn_ratio,
+ deform_ratio=deform_ratio,
+ drop=drop,
+ drop_path=drop_path,
+ with_cp=with_cp,
+ )
+ for _ in range(2)
+ ]
+ )
+ else:
+ self.extra_extractors = None
+
+ def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
+ x = self.injector(
+ query=x,
+ reference_points=deform_inputs1[0],
+ feat=c,
+ spatial_shapes=deform_inputs1[1],
+ level_start_index=deform_inputs1[2],
+ )
+ x = torch.cat((cls, x), dim=1)
+ for idx, blk in enumerate(blocks):
+ x = blk(x, H_toks, W_toks)
+ cls, x = (
+ x[
+ :,
+ :1,
+ ],
+ x[
+ :,
+ 1:,
+ ],
+ )
+ c = self.extractor(
+ query=c,
+ reference_points=deform_inputs2[0],
+ feat=x,
+ spatial_shapes=deform_inputs2[1],
+ level_start_index=deform_inputs2[2],
+ H=H_c,
+ W=W_c,
+ )
+ if self.extra_extractors is not None:
+ for extractor in self.extra_extractors:
+ c = extractor(
+ query=c,
+ reference_points=deform_inputs2[0],
+ feat=x,
+ spatial_shapes=deform_inputs2[1],
+ level_start_index=deform_inputs2[2],
+ H=H_c,
+ W=W_c,
+ )
+ return x, c, cls
+
+
+class SpatialPriorModule(nn.Module):
+ def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
+ super().__init__()
+ self.with_cp = with_cp
+
+ self.stem = nn.Sequential(
+ *[
+ nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.SyncBatchNorm(inplanes),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.SyncBatchNorm(inplanes),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.SyncBatchNorm(inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ]
+ )
+ self.conv2 = nn.Sequential(
+ *[
+ nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.SyncBatchNorm(2 * inplanes),
+ nn.ReLU(inplace=True),
+ ]
+ )
+ self.conv3 = nn.Sequential(
+ *[
+ nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.SyncBatchNorm(4 * inplanes),
+ nn.ReLU(inplace=True),
+ ]
+ )
+ self.conv4 = nn.Sequential(
+ *[
+ nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.SyncBatchNorm(4 * inplanes),
+ nn.ReLU(inplace=True),
+ ]
+ )
+ self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
+ self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
+ self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
+ self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, x):
+ def _inner_forward(x):
+ c1 = self.stem(x)
+ c2 = self.conv2(c1)
+ c3 = self.conv3(c2)
+ c4 = self.conv4(c3)
+ c1 = self.fc1(c1)
+ c2 = self.fc2(c2)
+ c3 = self.fc3(c3)
+ c4 = self.fc4(c4)
+
+ bs, dim, _, _ = c1.shape
+ # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s
+ c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s
+ c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s
+ c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s
+
+ return c1, c2, c3, c4
+
+ if self.with_cp and x.requires_grad:
+ outs = cp.checkpoint(_inner_forward, x)
+ else:
+ outs = _inner_forward(x)
+ return outs
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..864eb8738c44652d12b979fc811503f21cbb00dd
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = 0.0):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a147570451bd2fbd016ddfafbbfa33035cbd4f8
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit.py
@@ -0,0 +1,552 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""Vision Transformer (ViT) in PyTorch.
+
+A PyTorch implement of Vision Transformers as described in:
+
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+ - https://arxiv.org/abs/2010.11929
+
+`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+ - https://arxiv.org/abs/2106.10270
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import logging
+import math
+from functools import partial
+from itertools import repeat
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.runner import BaseModule, load_checkpoint
+from mmseg.ops import resize
+from mmseg.utils import get_root_logger
+from torch import Tensor
+
+from .drop_path import DropPath
+
+
+def to_2tuple(x):
+ return tuple(repeat(x, 2))
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ swiglu_hidden_features = int(2 * hidden_features / 3)
+ align_as = 8
+ swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
+ self.w1 = nn.Linear(in_features, swiglu_hidden_features)
+ self.w2 = nn.Linear(in_features, swiglu_hidden_features)
+ self.w3 = nn.Linear(swiglu_hidden_features, out_features)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x1 = self.w1(x)
+ x2 = self.w2(x)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding."""
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, H, W
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, H, W) -> Tensor:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowedAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant"
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.window_size = window_size
+ self.pad_mode = pad_mode
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ N_ = self.window_size * self.window_size
+ H_ = math.ceil(H / self.window_size) * self.window_size
+ W_ = math.ceil(W / self.window_size) * self.window_size
+
+ qkv = self.qkv(x) # [B, N, C]
+ qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W]
+ qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode)
+
+ qkv = F.unfold(
+ qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size)
+ )
+ B, C_kw_kw, L = qkv.shape # L - the num of windows
+ qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C]
+ qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ # q,k,v [B, L, num_head, N_, C/num_head]
+ attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
+ # if self.mask:
+ # attn = attn * mask
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
+ # attn @ v = [B, L, num_head, N_, C/num_head]
+ x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L)
+
+ x = F.fold(
+ x,
+ output_size=(H_, W_),
+ kernel_size=(self.window_size, self.window_size),
+ stride=(self.window_size, self.window_size),
+ ) # [B, C, H_, W_]
+ x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+# class WindowedAttention(nn.Module):
+# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"):
+# super().__init__()
+# self.num_heads = num_heads
+# head_dim = dim // num_heads
+# self.scale = head_dim ** -0.5
+#
+# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+# self.attn_drop = nn.Dropout(attn_drop)
+# self.proj = nn.Linear(dim, dim)
+# self.proj_drop = nn.Dropout(proj_drop)
+# self.window_size = window_size
+# self.pad_mode = pad_mode
+#
+# def forward(self, x, H, W):
+# B, N, C = x.shape
+#
+# N_ = self.window_size * self.window_size
+# H_ = math.ceil(H / self.window_size) * self.window_size
+# W_ = math.ceil(W / self.window_size) * self.window_size
+# x = x.view(B, H, W, C)
+# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode)
+#
+# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C
+# x = x.view(-1, N_, C)
+#
+# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
+# attn = attn.softmax(dim=-1)
+# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
+# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C)
+#
+# x = window_reverse(x, self.window_size, H_, W_)
+# x = x[:, :H, :W, :].reshape(B, N, C).contiguous()
+# x = self.proj(x)
+# x = self.proj_drop(x)
+# return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ windowed=False,
+ window_size=14,
+ pad_mode="constant",
+ layer_scale=False,
+ with_cp=False,
+ ffn_layer=Mlp,
+ memeff=False,
+ ):
+ super().__init__()
+ self.with_cp = with_cp
+ self.norm1 = norm_layer(dim)
+ if windowed:
+ self.attn = WindowedAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ pad_mode=pad_mode,
+ )
+ elif memeff:
+ self.attn = MemEffAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
+ )
+ else:
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.layer_scale = layer_scale
+ if layer_scale:
+ self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
+ self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x, H, W):
+ def _inner_forward(x):
+ if self.layer_scale:
+ x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+
+ return x
+
+
+class TIMMVisionTransformer(BaseModule):
+ """Vision Transformer.
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ layer_scale=True,
+ embed_layer=PatchEmbed,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ window_attn=False,
+ window_size=14,
+ pretrained=None,
+ with_cp=False,
+ pre_norm=False,
+ ffn_type="mlp",
+ memeff=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ pretrained: (str): pretrained path
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+ self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.pretrain_size = img_size
+ self.drop_path_rate = drop_path_rate
+ self.drop_rate = drop_rate
+ self.patch_size = patch_size
+
+ window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
+ window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
+ logging.info("window attention:", window_attn)
+ logging.info("window size:", window_size)
+ logging.info("layer scale:", layer_scale)
+
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN}
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.Sequential(
+ *[
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ windowed=window_attn[i],
+ window_size=window_size[i],
+ layer_scale=layer_scale,
+ with_cp=with_cp,
+ ffn_layer=ffn_types[ffn_type],
+ memeff=memeff,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # self.norm = norm_layer(embed_dim)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # For CLIP
+ if pre_norm:
+ norm_pre = norm_layer(embed_dim)
+ self.norm_pre = norm_pre
+ else:
+ self.norm_pre = nn.Identity()
+ self.init_weights(pretrained)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)
+
+ def forward_features(self, x):
+ x, H, W = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_token, x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+
+ # For CLIP
+ x = self.norm_pre(x)
+
+ for blk in self.blocks:
+ x = blk(x, H, W)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ @staticmethod
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
+ """Resize pos_embed weights.
+
+ Resize pos_embed using bicubic interpolate method.
+ Args:
+ pos_embed (torch.Tensor): Position embedding weights.
+ input_shpae (tuple): Tuple for (downsampled input image height,
+ downsampled input image width).
+ pos_shape (tuple): The resolution of downsampled origin training
+ image.
+ mode (str): Algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'``. Default: ``'nearest'``
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+ """
+ assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
+ pos_h, pos_w = pos_shape
+ # keep dim for easy deployment
+ cls_token_weight = pos_embed[:, 0:1]
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
+ pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+ pos_embed_weight = resize(pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+ return pos_embed
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc4f0f65e04ed764464d141607b3b2073220f6b
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmseg.models.builder import BACKBONES
+from torch.nn.init import normal_
+
+from ...ops.modules import MSDeformAttn
+from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
+from .vit import TIMMVisionTransformer
+
+
+@BACKBONES.register_module()
+class ViTAdapter(TIMMVisionTransformer):
+ def __init__(
+ self,
+ pretrain_size=224,
+ num_heads=12,
+ conv_inplane=64,
+ n_points=4,
+ deform_num_heads=6,
+ init_values=0.0,
+ interaction_indexes=None,
+ with_cffn=True,
+ cffn_ratio=0.25,
+ deform_ratio=1.0,
+ add_vit_feature=True,
+ pretrained=None,
+ use_extra_extractor=True,
+ freeze_vit=False,
+ use_cls=True,
+ with_cp=False,
+ *args,
+ **kwargs
+ ):
+
+ super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs)
+ if freeze_vit:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # self.num_classes = 80
+ self.use_cls = use_cls
+ if not self.use_cls:
+ self.cls_token = None
+ self.num_block = len(self.blocks)
+ self.pretrain_size = (pretrain_size, pretrain_size)
+ self.interaction_indexes = interaction_indexes
+ self.add_vit_feature = add_vit_feature
+ embed_dim = self.embed_dim
+
+ block_fn = InteractionBlockWithCls if use_cls else InteractionBlock
+
+ self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
+ self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
+ self.interactions = nn.Sequential(
+ *[
+ block_fn(
+ dim=embed_dim,
+ num_heads=deform_num_heads,
+ n_points=n_points,
+ init_values=init_values,
+ drop_path=self.drop_path_rate,
+ norm_layer=self.norm_layer,
+ with_cffn=with_cffn,
+ cffn_ratio=cffn_ratio,
+ deform_ratio=deform_ratio,
+ extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor),
+ with_cp=with_cp,
+ )
+ for i in range(len(interaction_indexes))
+ ]
+ )
+ self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
+ self.norm1 = nn.SyncBatchNorm(embed_dim)
+ self.norm2 = nn.SyncBatchNorm(embed_dim)
+ self.norm3 = nn.SyncBatchNorm(embed_dim)
+ self.norm4 = nn.SyncBatchNorm(embed_dim)
+
+ self.up.apply(self._init_weights)
+ self.spm.apply(self._init_weights)
+ self.interactions.apply(self._init_weights)
+ self.apply(self._init_deform_weights)
+ normal_(self.level_embed)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ pos_embed = pos_embed.reshape(
+ 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1
+ ).permute(0, 3, 1, 2)
+ pos_embed = (
+ F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
+ .reshape(1, -1, H * W)
+ .permute(0, 2, 1)
+ )
+ return pos_embed
+
+ def _init_deform_weights(self, m):
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+
+ def _add_level_embed(self, c2, c3, c4):
+ c2 = c2 + self.level_embed[0]
+ c3 = c3 + self.level_embed[1]
+ c4 = c4 + self.level_embed[2]
+ return c2, c3, c4
+
+ def forward(self, x):
+ deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)
+
+ # SPM forward
+ c1, c2, c3, c4 = self.spm(x)
+ c2, c3, c4 = self._add_level_embed(c2, c3, c4)
+ c = torch.cat([c2, c3, c4], dim=1)
+
+ # Patch Embedding forward
+ H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
+ x, H_toks, W_toks = self.patch_embed(x)
+ # print("H_toks, W_toks =", H_toks, W_toks)
+ bs, n, dim = x.shape
+ pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
+ if self.use_cls:
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_token, x), dim=1)
+ pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
+ x = self.pos_drop(x + pos_embed)
+ # For CLIP
+ x = self.norm_pre(x)
+
+ # Interaction
+ if self.use_cls:
+ cls, x = (
+ x[
+ :,
+ :1,
+ ],
+ x[
+ :,
+ 1:,
+ ],
+ )
+ outs = list()
+ for i, layer in enumerate(self.interactions):
+ indexes = self.interaction_indexes[i]
+ if self.use_cls:
+ x, c, cls = layer(
+ x,
+ c,
+ cls,
+ self.blocks[indexes[0] : indexes[-1] + 1],
+ deform_inputs1,
+ deform_inputs2,
+ H_c,
+ W_c,
+ H_toks,
+ W_toks,
+ )
+ else:
+ x, c = layer(
+ x,
+ c,
+ self.blocks[indexes[0] : indexes[-1] + 1],
+ deform_inputs1,
+ deform_inputs2,
+ H_c,
+ W_c,
+ H_toks,
+ W_toks,
+ )
+ outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())
+
+ # Split & Reshape
+ c2 = c[:, 0 : c2.size(1), :]
+ c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :]
+ c4 = c[:, c2.size(1) + c3.size(1) :, :]
+
+ c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
+ c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
+ c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
+ c1 = self.up(c2) + c1
+
+ if self.add_vit_feature:
+ x1, x2, x3, x4 = outs
+
+ x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
+ x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
+ x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
+ x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
+ # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks)
+ c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4
+
+ # Final Norm
+ f1 = self.norm1(c1)
+ f2 = self.norm2(c2)
+ f3 = self.norm3(c3)
+ f4 = self.norm4(c4)
+ return [f1, f2, f3, f4]
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/builder.py b/mapper/models/dinov2/eval/segmentation_m2f/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7cf7b919f6b0e8e00bde45bc244d9c29a36fed6
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/builder.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mmcv.utils import Registry
+
+TRANSFORMER = Registry("Transformer")
+MASK_ASSIGNERS = Registry("mask_assigner")
+MATCH_COST = Registry("match_cost")
+
+
+def build_match_cost(cfg):
+ """Build Match Cost."""
+ return MATCH_COST.build(cfg)
+
+
+def build_assigner(cfg):
+ """Build Assigner."""
+ return MASK_ASSIGNERS.build(cfg)
+
+
+def build_transformer(cfg):
+ """Build Transformer."""
+ return TRANSFORMER.build(cfg)
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..01f08b88950750337781fc671adfea2a935ea8fe
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .mask2former_head import Mask2FormerHead
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1705fc444fa8d1583d88fca36d7fe1e060db9e7
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py
@@ -0,0 +1,544 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
+from mmcv.ops import point_sample
+from mmcv.runner import ModuleList, force_fp32
+from mmseg.models.builder import HEADS, build_loss
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+
+from ...core import build_sampler, multi_apply, reduce_mean
+from ..builder import build_assigner
+from ..utils import get_uncertain_point_coords_with_randomness
+
+
+@HEADS.register_module()
+class Mask2FormerHead(BaseDecodeHead):
+ """Implements the Mask2Former head.
+
+ See `Masked-attention Mask Transformer for Universal Image
+ Segmentation `_ for details.
+
+ Args:
+ in_channels (list[int]): Number of channels in the input feature map.
+ feat_channels (int): Number of channels for features.
+ out_channels (int): Number of channels for output.
+ num_things_classes (int): Number of things.
+ num_stuff_classes (int): Number of stuff.
+ num_queries (int): Number of query in Transformer decoder.
+ pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
+ decoder. Defaults to None.
+ enforce_decoder_input_project (bool, optional): Whether to add
+ a layer to change the embed_dim of tranformer encoder in
+ pixel decoder to the embed_dim of transformer decoder.
+ Defaults to False.
+ transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder. Defaults to None.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder position encoding. Defaults to None.
+ loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
+ loss. Defaults to None.
+ loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
+ Defaults to None.
+ loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
+ Defaults to None.
+ train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
+ Mask2Former head.
+ test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
+ Mask2Former head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ feat_channels,
+ out_channels,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ pixel_decoder=None,
+ enforce_decoder_input_project=False,
+ transformer_decoder=None,
+ positional_encoding=None,
+ loss_cls=None,
+ loss_mask=None,
+ loss_dice=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None,
+ **kwargs,
+ ):
+ super(Mask2FormerHead, self).__init__(
+ in_channels=in_channels,
+ channels=feat_channels,
+ num_classes=(num_things_classes + num_stuff_classes),
+ init_cfg=init_cfg,
+ input_transform="multiple_select",
+ **kwargs,
+ )
+ self.num_things_classes = num_things_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.num_classes = self.num_things_classes + self.num_stuff_classes
+ self.num_queries = num_queries
+ self.num_transformer_feat_level = num_transformer_feat_level
+ self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
+ self.num_transformer_decoder_layers = transformer_decoder.num_layers
+ assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level
+ pixel_decoder_ = copy.deepcopy(pixel_decoder)
+ pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
+ self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
+ self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
+ self.decoder_embed_dims = self.transformer_decoder.embed_dims
+
+ self.decoder_input_projs = ModuleList()
+ # from low resolution to high resolution
+ for _ in range(num_transformer_feat_level):
+ if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project:
+ self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1))
+ else:
+ self.decoder_input_projs.append(nn.Identity())
+ self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
+ self.query_embed = nn.Embedding(self.num_queries, feat_channels)
+ self.query_feat = nn.Embedding(self.num_queries, feat_channels)
+ # from low resolution to high resolution
+ self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)
+
+ self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+ self.mask_embed = nn.Sequential(
+ nn.Linear(feat_channels, feat_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, feat_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, out_channels),
+ )
+ self.conv_seg = None # fix a bug here (conv_seg is not used)
+
+ self.test_cfg = test_cfg
+ self.train_cfg = train_cfg
+ if train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ self.sampler = build_sampler(self.train_cfg.sampler, context=self)
+ self.num_points = self.train_cfg.get("num_points", 12544)
+ self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
+ self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)
+
+ self.class_weight = loss_cls.class_weight
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.loss_dice = build_loss(loss_dice)
+
+ def init_weights(self):
+ for m in self.decoder_input_projs:
+ if isinstance(m, Conv2d):
+ caffe2_xavier_init(m, bias=0)
+
+ self.pixel_decoder.init_weights()
+
+ for p in self.transformer_decoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_normal_(p)
+
+ def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
+ """Compute classification and mask targets for all images for a decoder
+ layer.
+
+ Args:
+ cls_scores_list (list[Tensor]): Mask score logits from a single
+ decoder layer for all images. Each with shape [num_queries,
+ cls_out_channels].
+ mask_preds_list (list[Tensor]): Mask logits from a single decoder
+ layer for all images. Each with shape [num_queries, h, w].
+ gt_labels_list (list[Tensor]): Ground truth class indices for all
+ images. Each with shape (n, ), n is the sum of number of stuff
+ type and number of instance in a image.
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[list[Tensor]]: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels of all images.
+ Each with shape [num_queries, ].
+ - label_weights_list (list[Tensor]): Label weights of all
+ images.Each with shape [num_queries, ].
+ - mask_targets_list (list[Tensor]): Mask targets of all images.
+ Each with shape [num_queries, h, w].
+ - mask_weights_list (list[Tensor]): Mask weights of all images.
+ Each with shape [num_queries, ].
+ - num_total_pos (int): Number of positive samples in all
+ images.
+ - num_total_neg (int): Number of negative samples in all
+ images.
+ """
+ (
+ labels_list,
+ label_weights_list,
+ mask_targets_list,
+ mask_weights_list,
+ pos_inds_list,
+ neg_inds_list,
+ ) = multi_apply(
+ self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
+ )
+
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+ return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
+ """Compute classification and mask targets for one image.
+
+ Args:
+ cls_score (Tensor): Mask score logits from a single decoder layer
+ for one image. Shape (num_queries, cls_out_channels).
+ mask_pred (Tensor): Mask logits for a single decoder layer for one
+ image. Shape (num_queries, h, w).
+ gt_labels (Tensor): Ground truth class indices for one image with
+ shape (num_gts, ).
+ gt_masks (Tensor): Ground truth mask for each image, each with
+ shape (num_gts, h, w).
+ img_metas (dict): Image informtation.
+
+ Returns:
+ tuple[Tensor]: A tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image. \
+ shape (num_queries, ).
+ - label_weights (Tensor): Label weights of each image. \
+ shape (num_queries, ).
+ - mask_targets (Tensor): Mask targets of each image. \
+ shape (num_queries, h, w).
+ - mask_weights (Tensor): Mask weights of each image. \
+ shape (num_queries, ).
+ - pos_inds (Tensor): Sampled positive indices for each \
+ image.
+ - neg_inds (Tensor): Sampled negative indices for each \
+ image.
+ """
+ # sample points
+ num_queries = cls_score.shape[0]
+ num_gts = gt_labels.shape[0]
+
+ point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
+ # shape (num_queries, num_points)
+ mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1)
+ # shape (num_gts, num_points)
+ gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1)
+
+ # assign and sample
+ assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
+ sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label target
+ labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_labels.new_ones((self.num_queries,))
+
+ # mask target
+ mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
+ mask_weights = mask_pred.new_zeros((self.num_queries,))
+ mask_weights[pos_inds] = 1.0
+
+ return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)
+
+ def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
+ """Loss function for outputs from a single decoder layer.
+
+ Args:
+ cls_scores (Tensor): Mask score logits from a single decoder layer
+ for all images. Shape (batch_size, num_queries,
+ cls_out_channels). Note `cls_out_channels` should includes
+ background.
+ mask_preds (Tensor): Mask logits for a pixel decoder for all
+ images. Shape (batch_size, num_queries, h, w).
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image, each with shape (num_gts, ).
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (num_gts, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[Tensor]: Loss components for outputs from a single \
+ decoder layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
+ (
+ labels_list,
+ label_weights_list,
+ mask_targets_list,
+ mask_weights_list,
+ num_total_pos,
+ num_total_neg,
+ ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
+ # shape (batch_size, num_queries)
+ labels = torch.stack(labels_list, dim=0)
+ # shape (batch_size, num_queries)
+ label_weights = torch.stack(label_weights_list, dim=0)
+ # shape (num_total_gts, h, w)
+ mask_targets = torch.cat(mask_targets_list, dim=0)
+ # shape (batch_size, num_queries)
+ mask_weights = torch.stack(mask_weights_list, dim=0)
+
+ # classfication loss
+ # shape (batch_size * num_queries, )
+ cls_scores = cls_scores.flatten(0, 1)
+ labels = labels.flatten(0, 1)
+ label_weights = label_weights.flatten(0, 1)
+
+ class_weight = cls_scores.new_tensor(self.class_weight)
+ loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())
+
+ num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
+ num_total_masks = max(num_total_masks, 1)
+
+ # extract positive ones
+ # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
+ mask_preds = mask_preds[mask_weights > 0]
+
+ if mask_targets.shape[0] == 0:
+ # zero match
+ loss_dice = mask_preds.sum()
+ loss_mask = mask_preds.sum()
+ return loss_cls, loss_mask, loss_dice
+
+ with torch.no_grad():
+ points_coords = get_uncertain_point_coords_with_randomness(
+ mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
+ )
+ # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
+ mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
+ # shape (num_queries, h, w) -> (num_queries, num_points)
+ mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)
+
+ # dice loss
+ loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
+
+ # mask loss
+ # shape (num_queries, num_points) -> (num_queries * num_points, )
+ mask_point_preds = mask_point_preds.reshape(-1, 1)
+ # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
+ mask_point_targets = mask_point_targets.reshape(-1)
+ loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)
+
+ return loss_cls, loss_mask, loss_dice
+
+ @force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
+ def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
+ """Loss function.
+
+ Args:
+ all_cls_scores (Tensor): Classification scores for all decoder
+ layers with shape [num_decoder, batch_size, num_queries,
+ cls_out_channels].
+ all_mask_preds (Tensor): Mask scores for all decoder layers with
+ shape [num_decoder, batch_size, num_queries, h, w].
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (n, ). n is the sum of number of stuff type
+ and number of instance in a image.
+ gt_masks_list (list[Tensor]): Ground truth mask for each image with
+ shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_dec_layers = len(all_cls_scores)
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+ losses_cls, losses_mask, losses_dice = multi_apply(
+ self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list
+ )
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict["loss_cls"] = losses_cls[-1]
+ loss_dict["loss_mask"] = losses_mask[-1]
+ loss_dict["loss_dice"] = losses_dice[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
+ loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
+ loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i
+ loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i
+ num_dec_layer += 1
+ return loss_dict
+
+ def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
+ """Forward for head part which is called after every decoder layer.
+
+ Args:
+ decoder_out (Tensor): in shape (num_queries, batch_size, c).
+ mask_feature (Tensor): in shape (batch_size, c, h, w).
+ attn_mask_target_size (tuple[int, int]): target attention
+ mask size.
+
+ Returns:
+ tuple: A tuple contain three elements.
+
+ - cls_pred (Tensor): Classification scores in shape \
+ (batch_size, num_queries, cls_out_channels). \
+ Note `cls_out_channels` should includes background.
+ - mask_pred (Tensor): Mask scores in shape \
+ (batch_size, num_queries,h, w).
+ - attn_mask (Tensor): Attention mask in shape \
+ (batch_size * num_heads, num_queries, h, w).
+ """
+ decoder_out = self.transformer_decoder.post_norm(decoder_out)
+ decoder_out = decoder_out.transpose(0, 1)
+ # shape (num_queries, batch_size, c)
+ cls_pred = self.cls_embed(decoder_out)
+ # shape (num_queries, batch_size, c)
+ mask_embed = self.mask_embed(decoder_out)
+ # shape (num_queries, batch_size, h, w)
+ mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
+ attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
+ # shape (num_queries, batch_size, h, w) ->
+ # (batch_size * num_head, num_queries, h, w)
+ attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
+ attn_mask = attn_mask.sigmoid() < 0.5
+ attn_mask = attn_mask.detach()
+
+ return cls_pred, mask_pred, attn_mask
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (list[Tensor]): Multi scale Features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple: A tuple contains two elements.
+
+ - cls_pred_list (list[Tensor)]: Classification logits \
+ for each decoder layer. Each is a 3D-tensor with shape \
+ (batch_size, num_queries, cls_out_channels). \
+ Note `cls_out_channels` should includes background.
+ - mask_pred_list (list[Tensor]): Mask logits for each \
+ decoder layer. Each with shape (batch_size, num_queries, \
+ h, w).
+ """
+ batch_size = len(img_metas)
+ mask_features, multi_scale_memorys = self.pixel_decoder(feats)
+ # multi_scale_memorys (from low resolution to high resolution)
+ decoder_inputs = []
+ decoder_positional_encodings = []
+ for i in range(self.num_transformer_feat_level):
+ decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
+ # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+ decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
+ level_embed = self.level_embed.weight[i].view(1, 1, -1)
+ decoder_input = decoder_input + level_embed
+ # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+ mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool)
+ decoder_positional_encoding = self.decoder_positional_encoding(mask)
+ decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1)
+ decoder_inputs.append(decoder_input)
+ decoder_positional_encodings.append(decoder_positional_encoding)
+ # shape (num_queries, c) -> (num_queries, batch_size, c)
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))
+
+ cls_pred_list = []
+ mask_pred_list = []
+ cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
+ cls_pred_list.append(cls_pred)
+ mask_pred_list.append(mask_pred)
+
+ for i in range(self.num_transformer_decoder_layers):
+ level_idx = i % self.num_transformer_feat_level
+ # if a mask is all True(all background), then set it all False.
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ # cross_attn + self_attn
+ layer = self.transformer_decoder.layers[i]
+ attn_masks = [attn_mask, None]
+ query_feat = layer(
+ query=query_feat,
+ key=decoder_inputs[level_idx],
+ value=decoder_inputs[level_idx],
+ query_pos=query_embed,
+ key_pos=decoder_positional_encodings[level_idx],
+ attn_masks=attn_masks,
+ query_key_padding_mask=None,
+ # here we do not apply masking on padded region
+ key_padding_mask=None,
+ )
+ cls_pred, mask_pred, attn_mask = self.forward_head(
+ query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:]
+ )
+
+ cls_pred_list.append(cls_pred)
+ mask_pred_list.append(mask_pred)
+
+ return cls_pred_list, mask_pred_list
+
+ def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
+ """Forward function for training mode.
+
+ Args:
+ x (list[Tensor]): Multi-level features from the upstream network,
+ each is a 4D-tensor.
+ img_metas (list[Dict]): List of image information.
+ gt_semantic_seg (list[tensor]):Each element is the ground truth
+ of semantic segmentation with the shape (N, H, W).
+ train_cfg (dict): The training config, which not been used in
+ maskformer.
+ gt_labels (list[Tensor]): Each element is ground truth labels of
+ each box, shape (num_gts,).
+ gt_masks (list[BitmapMasks]): Each element is masks of instances
+ of a image, shape (num_gts, h, w).
+
+ Returns:
+ losses (dict[str, Tensor]): a dictionary of loss components
+ """
+
+ # forward
+ all_cls_scores, all_mask_preds = self(x, img_metas)
+
+ # loss
+ losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)
+
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Test segment without test-time aumengtation.
+
+ Only the output of last decoder layers was used.
+
+ Args:
+ inputs (list[Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ test_cfg (dict): Testing config.
+
+ Returns:
+ seg_mask (Tensor): Predicted semantic segmentation logits.
+ """
+ all_cls_scores, all_mask_preds = self(inputs, img_metas)
+ cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
+ ori_h, ori_w, _ = img_metas[0]["ori_shape"]
+
+ # semantic inference
+ cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
+ return seg_mask
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..229a887817372f4991b32354180592cfb236d728
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy
+from .dice_loss import DiceLoss
+from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1f9dd4aa52ebe94cc527db36b1c7fa2f53813e
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py
@@ -0,0 +1,279 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmseg.models.builder import LOSSES
+from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(
+ pred,
+ label,
+ weight=None,
+ class_weight=None,
+ reduction="mean",
+ avg_factor=None,
+ ignore_index=-100,
+ avg_non_ignore=False,
+):
+ """cross_entropy. The wrapper function for :func:`F.cross_entropy`
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ Default: None.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Default: None.
+ ignore_index (int): Specifies a target value that is ignored and
+ does not contribute to the input gradients. When
+ ``avg_non_ignore `` is ``True``, and the ``reduction`` is
+ ``''mean''``, the loss is averaged over non-ignored targets.
+ Defaults: -100.
+ avg_non_ignore (bool): The flag decides to whether the loss is
+ only averaged over non-ignored targets. Default: False.
+ `New in version 0.23.0.`
+ """
+
+ # class_weight is a manual rescaling weight given to each class.
+ # If given, has to be a Tensor of size C element-wise losses
+ loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ # average loss over non-ignored elements
+ # pytorch's official cross_entropy average loss over non-ignored elements
+ # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
+ if (avg_factor is None) and avg_non_ignore and reduction == "mean":
+ avg_factor = label.numel() - (label == ignore_index).sum().item()
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights = bin_label_weights * valid_mask
+
+ return bin_labels, bin_label_weights, valid_mask
+
+
+def binary_cross_entropy(
+ pred,
+ label,
+ weight=None,
+ reduction="mean",
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=-100,
+ avg_non_ignore=False,
+ **kwargs,
+):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ Note: In bce loss, label < 0 is invalid.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int): The label index to be ignored. Default: -100.
+ avg_non_ignore (bool): The flag decides to whether the loss is
+ only averaged over non-ignored targets. Default: False.
+ `New in version 0.23.0.`
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.size(1) == 1:
+ # For binary class segmentation, the shape of pred is
+ # [N, 1, H, W] and that of label is [N, H, W].
+ assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
+ pred = pred.squeeze()
+ if pred.dim() != label.dim():
+ assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
+ "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
+ )
+ # `weight` returned from `_expand_onehot_labels`
+ # has been treated for valid (non-ignore) pixels
+ label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
+ else:
+ # should mask out the ignored elements
+ valid_mask = ((label >= 0) & (label != ignore_index)).float()
+ if weight is not None:
+ weight = weight * valid_mask
+ else:
+ weight = valid_mask
+ # average loss over non-ignored and valid elements
+ if reduction == "mean" and avg_factor is None and avg_non_ignore:
+ avg_factor = valid_mask.sum().item()
+
+ loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(
+ pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
+):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ label (torch.Tensor): ``label`` indicates the class label of the mask'
+ corresponding object. This will be used to select the mask in the
+ of the class which the object belongs to when the mask prediction
+ if not class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert ignore_index is None, "BCE loss does not support ignore_index"
+ assert reduction == "mean" and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
+
+
+@LOSSES.register_module(force=True)
+class CrossEntropyLoss(nn.Module):
+ """CrossEntropyLoss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+ reduction (str, optional): . Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ loss_name (str, optional): Name of the loss item. If you want this loss
+ item to be included into the backward graph, `loss_` must be the
+ prefix of the name. Defaults to 'loss_ce'.
+ avg_non_ignore (bool): The flag decides to whether the loss is
+ only averaged over non-ignored targets. Default: False.
+ `New in version 0.23.0.`
+ """
+
+ def __init__(
+ self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction="mean",
+ class_weight=None,
+ loss_weight=1.0,
+ loss_name="loss_ce",
+ avg_non_ignore=False,
+ ):
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+ self.avg_non_ignore = avg_non_ignore
+ if not self.avg_non_ignore and self.reduction == "mean":
+ warnings.warn(
+ "Default ``avg_non_ignore`` is False, if you would like to "
+ "ignore the certain label and average loss over non-ignore "
+ "labels, which is the same with PyTorch official "
+ "cross_entropy, set ``avg_non_ignore=True``."
+ )
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+ self._loss_name = loss_name
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f"avg_non_ignore={self.avg_non_ignore}"
+ return s
+
+ def forward(
+ self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs
+ ):
+ """Forward function."""
+ assert reduction_override in (None, "none", "mean", "sum")
+ reduction = reduction_override if reduction_override else self.reduction
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+ # Note: for BCE loss, label < 0 is invalid.
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ avg_non_ignore=self.avg_non_ignore,
+ ignore_index=ignore_index,
+ **kwargs,
+ )
+ return loss_cls
+
+ @property
+ def loss_name(self):
+ """Loss Name.
+
+ This function must be implemented and will return the name of this
+ loss function. This name will be used to combine different loss items
+ by simple sum operation. In addition, if you want this loss item to be
+ included into the backward graph, `loss_` must be the prefix of the
+ name.
+
+ Returns:
+ str: The name of this loss item.
+ """
+ return self._loss_name
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc5ba893c502861032ed531283f225e183eb693
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from mmseg.models.builder import LOSSES
+from mmseg.models.losses.utils import weight_reduce_loss
+
+
+def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
+ """Calculate dice loss, which is proposed in
+ `V-Net: Fully Convolutional Neural Networks for Volumetric
+ Medical Image Segmentation `_.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *)
+ target (torch.Tensor): The learning label of the prediction,
+ shape (n, *), same shape of pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ eps (float): Avoid dividing by zero. Default: 1e-3.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+
+ input = pred.flatten(1)
+ target = target.flatten(1).float()
+
+ a = torch.sum(input * target, 1)
+ b = torch.sum(input * input, 1) + eps
+ c = torch.sum(target * target, 1) + eps
+ d = (2 * a) / (b + c)
+ loss = 1 - d
+ if weight is not None:
+ assert weight.ndim == loss.ndim
+ assert len(weight) == len(pred)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
+ """Calculate naive dice loss, the coefficient in the denominator is the
+ first power instead of the second power.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *)
+ target (torch.Tensor): The learning label of the prediction,
+ shape (n, *), same shape of pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ eps (float): Avoid dividing by zero. Default: 1e-3.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ input = pred.flatten(1)
+ target = target.flatten(1).float()
+
+ a = torch.sum(input * target, 1)
+ b = torch.sum(input, 1)
+ c = torch.sum(target, 1)
+ d = (2 * a + eps) / (b + c + eps)
+ loss = 1 - d
+ if weight is not None:
+ assert weight.ndim == loss.ndim
+ assert len(weight) == len(pred)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module(force=True)
+class DiceLoss(nn.Module):
+ def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3):
+ """Dice Loss, there are two forms of dice loss is supported:
+
+ - the one proposed in `V-Net: Fully Convolutional Neural
+ Networks for Volumetric Medical Image Segmentation
+ `_.
+ - the dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power.
+
+ Args:
+ use_sigmoid (bool, optional): Whether to the prediction is
+ used for sigmoid or softmax. Defaults to True.
+ activate (bool): Whether to activate the predictions inside,
+ this will disable the inside sigmoid operation.
+ Defaults to True.
+ reduction (str, optional): The method used
+ to reduce the loss. Options are "none",
+ "mean" and "sum". Defaults to 'mean'.
+ naive_dice (bool, optional): If false, use the dice
+ loss defined in the V-Net paper, otherwise, use the
+ naive dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power.Defaults to False.
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ eps (float): Avoid dividing by zero. Defaults to 1e-3.
+ """
+
+ super(DiceLoss, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ self.reduction = reduction
+ self.naive_dice = naive_dice
+ self.loss_weight = loss_weight
+ self.eps = eps
+ self.activate = activate
+
+ def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *).
+ target (torch.Tensor): The label of the prediction,
+ shape (n, *), same shape of pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+
+ assert reduction_override in (None, "none", "mean", "sum")
+ reduction = reduction_override if reduction_override else self.reduction
+
+ if self.activate:
+ if self.use_sigmoid:
+ pred = pred.sigmoid()
+ else:
+ raise NotImplementedError
+
+ if self.naive_dice:
+ loss = self.loss_weight * naive_dice_loss(
+ pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
+ )
+ else:
+ loss = self.loss_weight * dice_loss(
+ pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
+ )
+
+ return loss
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/match_costs.py
new file mode 100644
index 0000000000000000000000000000000000000000..4917d2a939c01398dd49c0d90b06f4c37d283ce0
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/losses/match_costs.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from ..builder import MATCH_COST
+
+
+@MATCH_COST.register_module()
+class ClassificationCost:
+ """ClsSoftmaxCost.Borrow from
+ mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+
+ Examples:
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+
+ def __init__(self, weight=1.0):
+ self.weight = weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ # Following the official DETR repo, contrary to the loss that
+ # NLL is used, we approximate it in 1 - cls_score[gt_label].
+ # The 1 is a constant that doesn't change the matching,
+ # so it can be omitted.
+ cls_score = cls_pred.softmax(-1)
+ cls_cost = -cls_score[:, gt_labels]
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class DiceCost:
+ """Cost of mask assignments based on dice losses.
+
+ Args:
+ weight (int | float, optional): loss_weight. Defaults to 1.
+ pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
+ Defaults to False.
+ eps (float, optional): default 1e-12.
+ """
+
+ def __init__(self, weight=1.0, pred_act=False, eps=1e-3):
+ self.weight = weight
+ self.pred_act = pred_act
+ self.eps = eps
+
+ def binary_mask_dice_loss(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction in shape (N1, H, W).
+ gt_masks (Tensor): Ground truth in shape (N2, H, W)
+ store 0 or 1, 0 for negative class and 1 for
+ positive class.
+
+ Returns:
+ Tensor: Dice cost matrix in shape (N1, N2).
+ """
+ mask_preds = mask_preds.reshape((mask_preds.shape[0], -1))
+ gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float()
+ numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks)
+ denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
+ loss = 1 - (numerator + self.eps) / (denominator + self.eps)
+ return loss
+
+ def __call__(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction logits in shape (N1, H, W).
+ gt_masks (Tensor): Ground truth in shape (N2, H, W).
+
+ Returns:
+ Tensor: Dice cost matrix in shape (N1, N2).
+ """
+ if self.pred_act:
+ mask_preds = mask_preds.sigmoid()
+ dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
+ return dice_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class CrossEntropyLossCost:
+ """CrossEntropyLossCost.
+
+ Args:
+ weight (int | float, optional): loss weight. Defaults to 1.
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ of softmax. Defaults to True.
+ """
+
+ def __init__(self, weight=1.0, use_sigmoid=True):
+ assert use_sigmoid, "use_sigmoid = False is not supported yet."
+ self.weight = weight
+ self.use_sigmoid = use_sigmoid
+
+ def _binary_cross_entropy(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
+ (num_query, *).
+ gt_labels (Tensor): The learning label of prediction with
+ shape (num_gt, *).
+ Returns:
+ Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1).float()
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none")
+ neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none")
+ cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels)
+ cls_cost = cls_cost / n
+
+ return cls_cost
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits.
+ gt_labels (Tensor): Labels.
+ Returns:
+ Tensor: Cross entropy cost matrix with weight in
+ shape (num_query, num_gt).
+ """
+ if self.use_sigmoid:
+ cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
+ else:
+ raise NotImplementedError
+
+ return cls_cost * self.weight
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a60db4de31238cb38e078683e5ca265839fe60
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..db1947175917f73f3f24184cb09c78e092d46ef8
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init
+from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
+from mmcv.runner import BaseModule, ModuleList
+
+from ...core.anchor import MlvlPointGenerator
+from ..utils.transformer import MultiScaleDeformableAttention
+
+
+@PLUGIN_LAYERS.register_module()
+class MSDeformAttnPixelDecoder(BaseModule):
+ """Pixel decoder with multi-scale deformable attention.
+
+ Args:
+ in_channels (list[int] | tuple[int]): Number of channels in the
+ input feature maps.
+ strides (list[int] | tuple[int]): Output strides of feature from
+ backbone.
+ feat_channels (int): Number of channels for feature.
+ out_channels (int): Number of channels for output.
+ num_outs (int): Number of output scales.
+ norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+ Defaults to dict(type='GN', num_groups=32).
+ act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+ Defaults to dict(type='ReLU').
+ encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
+ encoder. Defaults to `DetrTransformerEncoder`.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer encoder position encoding. Defaults to
+ dict(type='SinePositionalEncoding', num_feats=128,
+ normalize=True).
+ init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+ """
+
+ def __init__(
+ self,
+ in_channels=[256, 512, 1024, 2048],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_outs=3,
+ norm_cfg=dict(type="GN", num_groups=32),
+ act_cfg=dict(type="ReLU"),
+ encoder=dict(
+ type="DetrTransformerEncoder",
+ num_layers=6,
+ transformerlayers=dict(
+ type="BaseTransformerLayer",
+ attn_cfgs=dict(
+ type="MultiScaleDeformableAttention",
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None,
+ ),
+ feedforward_channels=1024,
+ ffn_dropout=0.0,
+ operation_order=("self_attn", "norm", "ffn", "norm"),
+ ),
+ init_cfg=None,
+ ),
+ positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True),
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ self.strides = strides
+ self.num_input_levels = len(in_channels)
+ self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels
+ assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one"
+ input_conv_list = []
+ # from top to down (low to high resolution)
+ for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1):
+ input_conv = ConvModule(
+ in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True
+ )
+ input_conv_list.append(input_conv)
+ self.input_convs = ModuleList(input_conv_list)
+
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.postional_encoding = build_positional_encoding(positional_encoding)
+ # high resolution to low resolution
+ self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels)
+
+ # fpn-like structure
+ self.lateral_convs = ModuleList()
+ self.output_convs = ModuleList()
+ self.use_bias = norm_cfg is None
+ # from top to down (low to high resolution)
+ # fpn for the rest features that didn't pass in encoder
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
+ lateral_conv = ConvModule(
+ in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None
+ )
+ output_conv = ConvModule(
+ feat_channels,
+ feat_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ )
+ self.lateral_convs.append(lateral_conv)
+ self.output_convs.append(output_conv)
+
+ self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.num_outs = num_outs
+ self.point_generator = MlvlPointGenerator(strides)
+
+ def init_weights(self):
+ """Initialize weights."""
+ for i in range(0, self.num_encoder_levels):
+ xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform")
+
+ for i in range(0, self.num_input_levels - self.num_encoder_levels):
+ caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+ caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+ caffe2_xavier_init(self.mask_feature, bias=0)
+
+ normal_init(self.level_encoding, mean=0, std=1)
+ for p in self.encoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_normal_(p)
+
+ # init_weights defined in MultiScaleDeformableAttention
+ for layer in self.encoder.layers:
+ for attn in layer.attentions:
+ if isinstance(attn, MultiScaleDeformableAttention):
+ attn.init_weights()
+
+ def forward(self, feats):
+ """
+ Args:
+ feats (list[Tensor]): Feature maps of each level. Each has
+ shape of (batch_size, c, h, w).
+
+ Returns:
+ tuple: A tuple containing the following:
+
+ - mask_feature (Tensor): shape (batch_size, c, h, w).
+ - multi_scale_features (list[Tensor]): Multi scale \
+ features, each in shape (batch_size, c, h, w).
+ """
+ # generate padding mask for each level, for each image
+ batch_size = feats[0].shape[0]
+ encoder_input_list = []
+ padding_mask_list = []
+ level_positional_encoding_list = []
+ spatial_shapes = []
+ reference_points_list = []
+ for i in range(self.num_encoder_levels):
+ level_idx = self.num_input_levels - i - 1
+ feat = feats[level_idx]
+ feat_projected = self.input_convs[i](feat)
+ h, w = feat.shape[-2:]
+
+ # no padding
+ padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool)
+ pos_embed = self.postional_encoding(padding_mask_resized)
+ level_embed = self.level_encoding.weight[i]
+ level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
+ # (h_i * w_i, 2)
+ reference_points = self.point_generator.single_level_grid_priors(
+ feat.shape[-2:], level_idx, device=feat.device
+ )
+ # normalize
+ factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
+ reference_points = reference_points / factor
+
+ # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
+ feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
+ level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
+ padding_mask_resized = padding_mask_resized.flatten(1)
+
+ encoder_input_list.append(feat_projected)
+ padding_mask_list.append(padding_mask_resized)
+ level_positional_encoding_list.append(level_pos_embed)
+ spatial_shapes.append(feat.shape[-2:])
+ reference_points_list.append(reference_points)
+ # shape (batch_size, total_num_query),
+ # total_num_query=sum([., h_i * w_i,.])
+ padding_masks = torch.cat(padding_mask_list, dim=1)
+ # shape (total_num_query, batch_size, c)
+ encoder_inputs = torch.cat(encoder_input_list, dim=0)
+ level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0)
+ device = encoder_inputs.device
+ # shape (num_encoder_levels, 2), from low
+ # resolution to high resolution
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device)
+ # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ reference_points = torch.cat(reference_points_list, dim=0)
+ reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1)
+ valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2))
+ # shape (num_total_query, batch_size, c)
+ memory = self.encoder(
+ query=encoder_inputs,
+ key=None,
+ value=None,
+ query_pos=level_positional_encodings,
+ key_pos=None,
+ attn_masks=None,
+ key_padding_mask=None,
+ query_key_padding_mask=padding_masks,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_radios=valid_radios,
+ )
+ # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
+ memory = memory.permute(1, 2, 0)
+
+ # from low resolution to high resolution
+ num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
+ outs = torch.split(memory, num_query_per_level, dim=-1)
+ outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)]
+
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
+ x = feats[i]
+ cur_feat = self.lateral_convs[i](x)
+ y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False)
+ y = self.output_convs[i](y)
+ outs.append(y)
+ multi_scale_features = outs[: self.num_outs]
+
+ mask_feature = self.mask_feature(outs[-1])
+ return mask_feature, multi_scale_features
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf0062691e4889612e118f28ced853cd0bc33db
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .encoder_decoder_mask2former import EncoderDecoderMask2Former
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe572c9d317303bff8d51b85217d144906ebfe7
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py
@@ -0,0 +1,271 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmseg.core import add_prefix
+from mmseg.models import builder
+from mmseg.models.builder import SEGMENTORS
+from mmseg.models.segmentors.base import BaseSegmentor
+from mmseg.ops import resize
+
+
+@SEGMENTORS.register_module()
+class EncoderDecoderMask2Former(BaseSegmentor):
+ """Encoder Decoder segmentors.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+
+ def __init__(
+ self,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None,
+ ):
+ super(EncoderDecoderMask2Former, self).__init__(init_cfg)
+ if pretrained is not None:
+ assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight"
+ backbone.pretrained = pretrained
+ self.backbone = builder.build_backbone(backbone)
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+ decode_head.update(train_cfg=train_cfg)
+ decode_head.update(test_cfg=test_cfg)
+ self._init_decode_head(decode_head)
+ self._init_auxiliary_head(auxiliary_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ self.decode_head = builder.build_head(decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+
+ def _init_auxiliary_head(self, auxiliary_head):
+ """Initialize ``auxiliary_head``"""
+ if auxiliary_head is not None:
+ if isinstance(auxiliary_head, list):
+ self.auxiliary_head = nn.ModuleList()
+ for head_cfg in auxiliary_head:
+ self.auxiliary_head.append(builder.build_head(head_cfg))
+ else:
+ self.auxiliary_head = builder.build_head(auxiliary_head)
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs)
+
+ losses.update(add_prefix(loss_decode, "decode"))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ return seg_logits
+
+ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for auxiliary head in
+ training."""
+ losses = dict()
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for idx, aux_head in enumerate(self.auxiliary_head):
+ loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_aux, f"aux_{idx}"))
+ else:
+ loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_aux, "aux"))
+
+ return losses
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ seg_logit = self.encode_decode(img, None)
+
+ return seg_logit
+
+ def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs)
+ losses.update(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg)
+ losses.update(loss_aux)
+
+ return losses
+
+ def slide_inference(self, img, img_meta, rescale):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = img.size()
+ num_classes = self.num_classes
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ crop_seg_logit = self.encode_decode(crop_img, img_meta)
+ preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ if rescale:
+ preds = resize(
+ preds,
+ size=img_meta[0]["ori_shape"][:2],
+ mode="bilinear",
+ align_corners=self.align_corners,
+ warning=False,
+ )
+ return preds
+
+ def whole_inference(self, img, img_meta, rescale):
+ """Inference with full image."""
+
+ seg_logit = self.encode_decode(img, img_meta)
+ if rescale:
+ # support dynamic shape for onnx
+ if torch.onnx.is_in_onnx_export():
+ size = img.shape[2:]
+ else:
+ size = img_meta[0]["ori_shape"][:2]
+ seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False)
+
+ return seg_logit
+
+ def inference(self, img, img_meta, rescale):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output segmentation map.
+ """
+
+ assert self.test_cfg.mode in ["slide", "whole"]
+ ori_shape = img_meta[0]["ori_shape"]
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+ if self.test_cfg.mode == "slide":
+ seg_logit = self.slide_inference(img, img_meta, rescale)
+ else:
+ seg_logit = self.whole_inference(img, img_meta, rescale)
+ output = F.softmax(seg_logit, dim=1)
+ flip = img_meta[0]["flip"]
+ if flip:
+ flip_direction = img_meta[0]["flip_direction"]
+ assert flip_direction in ["horizontal", "vertical"]
+ if flip_direction == "horizontal":
+ output = output.flip(dims=(3,))
+ elif flip_direction == "vertical":
+ output = output.flip(dims=(2,))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ seg_logit = self.inference(img, img_meta, rescale)
+ seg_pred = seg_logit.argmax(dim=1)
+ if torch.onnx.is_in_onnx_export():
+ # our inference backend only support 4D output
+ seg_pred = seg_pred.unsqueeze(0)
+ return seg_pred
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented seg logit inplace
+ seg_logit = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
+ seg_logit += cur_seg_logit
+ seg_logit /= len(imgs)
+ seg_pred = seg_logit.argmax(dim=1)
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7fdc1668b1015c8feea8fa1a4691bc0ebdbd936
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .assigner import MaskHungarianAssigner
+from .point_sample import get_uncertain_point_coords_with_randomness
+from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding
+from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb08fc1bb2e36336989b45a1d3850f260c05963
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/assigner.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from ..builder import MASK_ASSIGNERS, build_match_cost
+
+try:
+ from scipy.optimize import linear_sum_assignment
+except ImportError:
+ linear_sum_assignment = None
+
+
+class AssignResult(metaclass=ABCMeta):
+ """Collection of assign results."""
+
+ def __init__(self, num_gts, gt_inds, labels):
+ self.num_gts = num_gts
+ self.gt_inds = gt_inds
+ self.labels = labels
+
+ @property
+ def info(self):
+ info = {
+ "num_gts": self.num_gts,
+ "gt_inds": self.gt_inds,
+ "labels": self.labels,
+ }
+ return info
+
+
+class BaseAssigner(metaclass=ABCMeta):
+ """Base assigner that assigns boxes to ground truth boxes."""
+
+ @abstractmethod
+ def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
+ """Assign boxes to either a ground truth boxes or a negative boxes."""
+ pass
+
+
+@MASK_ASSIGNERS.register_module()
+class MaskHungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth for
+ mask.
+
+ This class computes an assignment between the targets and the predictions
+ based on the costs. The costs are weighted sum of three components:
+ classification cost, regression L1 cost and regression iou cost. The
+ targets don't include the no_object, so generally there are more
+ predictions than targets. After the one-to-one matching, the un-matched
+ are treated as backgrounds. Thus each query prediction will be assigned
+ with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
+ mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
+ dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
+ """
+
+ def __init__(
+ self,
+ cls_cost=dict(type="ClassificationCost", weight=1.0),
+ dice_cost=dict(type="DiceCost", weight=1.0),
+ mask_cost=dict(type="MaskFocalCost", weight=1.0),
+ ):
+ self.cls_cost = build_match_cost(cls_cost)
+ self.dice_cost = build_match_cost(dice_cost)
+ self.mask_cost = build_match_cost(mask_cost)
+
+ def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
+ """Computes one-to-one matching based on the weighted costs.
+
+ This method assign each query prediction to a ground truth or
+ background. The `assigned_gt_inds` with -1 means don't care,
+ 0 means negative sample, and positive number is the index (1-based)
+ of assigned gt.
+ The assignment is done in the following steps, the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+ mask_pred (Tensor): Predicted mask, shape [num_query, h, w]
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
+ gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
+ img_meta (dict): Meta information for current image.
+ gt_masks_ignore (Tensor, optional): Ground truth masks that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."
+ num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0]
+
+ # 1. assign -1 by default
+ assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
+ assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
+ if num_gts == 0 or num_queries == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
+
+ # 2. compute the weighted costs
+ # classification and maskcost.
+ if self.cls_cost.weight != 0 and cls_pred is not None:
+ cls_cost = self.cls_cost(cls_pred, gt_labels)
+ else:
+ cls_cost = 0
+
+ if self.mask_cost.weight != 0:
+ # mask_pred shape = [nq, h, w]
+ # gt_mask shape = [ng, h, w]
+ # mask_cost shape = [nq, ng]
+ mask_cost = self.mask_cost(mask_pred, gt_masks)
+ else:
+ mask_cost = 0
+
+ if self.dice_cost.weight != 0:
+ dice_cost = self.dice_cost(mask_pred, gt_masks)
+ else:
+ dice_cost = 0
+ cost = cls_cost + mask_cost + dice_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" ' "to install scipy first.")
+
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+ return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1134082bafb51432618a9632592db070f87284
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/point_sample.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+from mmcv.ops import point_sample
+
+
+def get_uncertainty(mask_pred, labels):
+ """Estimate uncertainty based on pred logits.
+
+ We estimate uncertainty as L1 distance between 0.0 and the logits
+ prediction in 'mask_pred' for the foreground class in `classes`.
+
+ Args:
+ mask_pred (Tensor): mask predication logits, shape (num_rois,
+ num_classes, mask_height, mask_width).
+
+ labels (list[Tensor]): Either predicted or ground truth label for
+ each predicted mask, of length num_rois.
+
+ Returns:
+ scores (Tensor): Uncertainty scores with the most uncertain
+ locations having the highest uncertainty score,
+ shape (num_rois, 1, mask_height, mask_width)
+ """
+ if mask_pred.shape[1] == 1:
+ gt_class_logits = mask_pred.clone()
+ else:
+ inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
+ gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
+ return -torch.abs(gt_class_logits)
+
+
+def get_uncertain_point_coords_with_randomness(
+ mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
+):
+ """Get ``num_points`` most uncertain points with random points during
+ train.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'get_uncertainty()' function that takes point's logit prediction as
+ input.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (list): The ground truth class for each instance.
+ num_points (int): The number of points to sample.
+ oversample_ratio (int): Oversampling parameter.
+ importance_sample_ratio (float): Ratio of points that are sampled
+ via importnace sampling.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+ that contains the coordinates sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = mask_pred.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device)
+ point_logits = point_sample(mask_pred, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = get_uncertainty(point_logits, labels)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device)
+ point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+ return point_coords
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5d6fabe946d06fe97cc799da47bae93758b34e
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
+from mmcv.runner import BaseModule
+
+
+@POSITIONAL_ENCODING.register_module()
+class SinePositionalEncoding(BaseModule):
+ """Position encoding with sine and cosine functions.
+
+ See `End-to-End Object Detection with Transformers
+ `_ for details.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+ for each position is 2 times of this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Defaults to 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Defaults to False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when `normalize` is True.
+ Defaults to 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Defaults to 1e-6.
+ offset (float): offset add to embed when do the normalization.
+ Defaults to 0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(
+ self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None
+ ):
+ super(SinePositionalEncoding, self).__init__(init_cfg)
+ if normalize:
+ assert isinstance(scale, (float, int)), (
+ "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}"
+ )
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+ self.offset = offset
+
+ def forward(self, mask):
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ # For convenience of exporting to ONNX, it's required to convert
+ # `masks` from bool to int.
+ mask = mask.to(torch.int)
+ not_mask = 1 - mask # logical_not
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ B, H, W = mask.size()
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f"(num_feats={self.num_feats}, "
+ repr_str += f"temperature={self.temperature}, "
+ repr_str += f"normalize={self.normalize}, "
+ repr_str += f"scale={self.scale}, "
+ repr_str += f"eps={self.eps})"
+ return repr_str
+
+
+@POSITIONAL_ENCODING.register_module()
+class LearnedPositionalEncoding(BaseModule):
+ """Position embedding with learnable embedding weights.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+ each position is 2 times of this value.
+ row_num_embed (int, optional): The dictionary size of row embeddings.
+ Default 50.
+ col_num_embed (int, optional): The dictionary size of col embeddings.
+ Default 50.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")):
+ super(LearnedPositionalEncoding, self).__init__(init_cfg)
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+
+ def forward(self, mask):
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = (
+ torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .repeat(mask.shape[0], 1, 1, 1)
+ )
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f"(num_feats={self.num_feats}, "
+ repr_str += f"row_num_embed={self.row_num_embed}, "
+ repr_str += f"col_num_embed={self.col_num_embed})"
+ return repr_str
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8befe6011a34d5ccecb82c8b17b61e19f732f96b
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/models/utils/transformer.py
@@ -0,0 +1,989 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import warnings
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE
+from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence
+from mmcv.runner.base_module import BaseModule, Sequential
+from mmcv.utils import deprecated_api_warning, to_2tuple
+from torch.nn.init import normal_
+
+from ..builder import TRANSFORMER
+
+try:
+ from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
+
+except ImportError:
+ warnings.warn(
+ "`MultiScaleDeformableAttention` in MMCV has been moved to "
+ "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV"
+ )
+ from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
+
+
+class AdaptivePadding(nn.Module):
+ """Applies padding to input (if needed) so that input can get fully covered
+ by filter you specified. It support two modes "same" and "corner". The
+ "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
+ input. The "corner" mode would pad zero to bottom right.
+
+ Args:
+ kernel_size (int | tuple): Size of the kernel:
+ stride (int | tuple): Stride of the filter. Default: 1:
+ dilation (int | tuple): Spacing between kernel elements.
+ Default: 1
+ padding (str): Support "same" and "corner", "corner" mode
+ would pad zero to bottom right, and "same" mode would
+ pad zero around input. Default: "corner".
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ >>> kernel_size=kernel_size,
+ >>> stride=stride,
+ >>> dilation=dilation,
+ >>> padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
+
+ super(AdaptivePadding, self).__init__()
+
+ assert padding in ("same", "corner")
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ padding = to_2tuple(padding)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == "corner":
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == "same":
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return x
+
+
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+ This layer groups feature map by kernel_size, and applies norm and linear
+ layers to the grouped feature map. Our implementation uses `nn.Unfold` to
+ merge patch, which is about 25% faster than original implementation.
+ Instead, we need to modify pretrained models for compatibility.
+
+ Args:
+ in_channels (int): The num of input channels.
+ to gets fully covered by filter and stride you specified..
+ Default: True.
+ out_channels (int): The num of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | string ): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Defaults: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=None,
+ padding="corner",
+ dilation=1,
+ bias=False,
+ norm_cfg=dict(type="LN"),
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if stride:
+ stride = stride
+ else:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding
+ )
+ # disable the padding of unfold
+ padding = 0
+ else:
+ self.adap_padding = None
+
+ padding = to_2tuple(padding)
+ self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
+
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+ else:
+ self.norm = None
+
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+ def forward(self, x, input_size):
+ """
+ Args:
+ x (Tensor): Has shape (B, H*W, C_in).
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+ Default: None.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (Merged_H, Merged_W).
+ """
+ B, L, C = x.shape
+ assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}"
+
+ H, W = input_size
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
+ # but need to modify pretrained model for compatibility
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+ H, W = x.shape[-2:]
+
+ x = self.sampler(x)
+ # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+
+ out_h = (
+ H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1
+ ) // self.sampler.stride[0] + 1
+ out_w = (
+ W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1
+ ) // self.sampler.stride[1] + 1
+
+ output_size = (out_h, out_w)
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
+ x = self.norm(x) if self.norm else x
+ x = self.reduction(x)
+ return x, output_size
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ """Inverse function of sigmoid.
+
+ Args:
+ x (Tensor): The tensor to do the
+ inverse.
+ eps (float): EPS avoid numerical
+ overflow. Defaults 1e-5.
+ Returns:
+ Tensor: The x has passed the inverse
+ function of sigmoid, has same
+ shape with input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+@FEEDFORWARD_NETWORK.register_module(force=True)
+class FFN(BaseModule):
+ """Implements feed-forward networks (FFNs) with identity connection.
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Default: 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='ReLU')
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN")
+ def __init__(
+ self,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ with_cp=False,
+ **kwargs,
+ ):
+ super().__init__(init_cfg)
+ assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}."
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+ self.with_cp = with_cp
+ layers = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ @deprecated_api_warning({"residual": "identity"}, cls_name="FFN")
+ def forward(self, x, identity=None):
+ """Forward function for `FFN`.
+ The function would add x to the output tensor if residue is None.
+ """
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.layers, x)
+ else:
+ out = self.layers(x)
+
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
+@TRANSFORMER_LAYER.register_module()
+class DetrTransformerDecoderLayer(BaseTransformerLayer):
+ """Implements decoder layer in DETR transformer.
+
+ Args:
+ attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
+ Configs for self_attention or cross_attention, the order
+ should be consistent with it in `operation_order`. If it is
+ a dict, it would be expand to the number of attention in
+ `operation_order`.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ ffn_dropout (float): Probability of an element to be zeroed
+ in ffn. Default 0.0.
+ operation_order (tuple[str]): The execution order of operation
+ in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+ Default:None
+ act_cfg (dict): The activation config for FFNs. Default: `LN`
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: `LN`.
+ ffn_num_fcs (int): The number of fully-connected layers in FFNs.
+ Default:2.
+ """
+
+ def __init__(
+ self,
+ attn_cfgs,
+ feedforward_channels,
+ ffn_dropout=0.0,
+ operation_order=None,
+ act_cfg=dict(type="ReLU", inplace=True),
+ norm_cfg=dict(type="LN"),
+ ffn_num_fcs=2,
+ **kwargs,
+ ):
+ super(DetrTransformerDecoderLayer, self).__init__(
+ attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ ffn_num_fcs=ffn_num_fcs,
+ **kwargs,
+ )
+ assert len(operation_order) == 6
+ assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"])
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerEncoder(TransformerLayerSequence):
+ """TransformerEncoder of DETR.
+
+ Args:
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`. Only used when `self.pre_norm` is `True`
+ """
+
+ def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs):
+ super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
+ else:
+ assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg"
+ self.post_norm = None
+
+ def forward(self, *args, **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
+ if self.post_norm is not None:
+ x = self.post_norm(x)
+ return x
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+
+ def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs):
+
+ super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
+ else:
+ self.post_norm = None
+
+ def forward(self, query, *args, **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ if not self.return_intermediate:
+ x = super().forward(query, *args, **kwargs)
+ if self.post_norm:
+ x = self.post_norm(x)[None]
+ return x
+
+ intermediate = []
+ for layer in self.layers:
+ query = layer(query, *args, **kwargs)
+ if self.return_intermediate:
+ if self.post_norm is not None:
+ intermediate.append(self.post_norm(query))
+ else:
+ intermediate.append(query)
+ return torch.stack(intermediate)
+
+
+@TRANSFORMER.register_module()
+class Transformer(BaseModule):
+ """Implements the DETR transformer.
+
+ Following the official DETR implementation, this module copy-paste
+ from torch.nn.Transformer with modifications:
+
+ * positional encodings are passed in MultiheadAttention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+
+ See `paper: End-to-End Object Detection with Transformers
+ `_ for details.
+
+ Args:
+ encoder (`mmcv.ConfigDict` | Dict): Config of
+ TransformerEncoder. Defaults to None.
+ decoder ((`mmcv.ConfigDict` | Dict)): Config of
+ TransformerDecoder. Defaults to None
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, encoder=None, decoder=None, init_cfg=None):
+ super(Transformer, self).__init__(init_cfg=init_cfg)
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.decoder = build_transformer_layer_sequence(decoder)
+ self.embed_dims = self.encoder.embed_dims
+
+ def init_weights(self):
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, "weight") and m.weight.dim() > 1:
+ xavier_init(m, distribution="uniform")
+ self._is_init = True
+
+ def forward(self, x, mask, query_embed, pos_embed):
+ """Forward function for `Transformer`.
+
+ Args:
+ x (Tensor): Input query with shape [bs, c, h, w] where
+ c = embed_dims.
+ mask (Tensor): The key_padding_mask used for encoder and decoder,
+ with shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder, with shape
+ [num_query, c].
+ pos_embed (Tensor): The positional encoding for encoder and
+ decoder, with the same shape as `x`.
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - out_dec: Output from decoder. If return_intermediate_dec \
+ is True output has shape [num_dec_layers, bs,
+ num_query, embed_dims], else has shape [1, bs, \
+ num_query, embed_dims].
+ - memory: Output results from encoder, with shape \
+ [bs, embed_dims, h, w].
+ """
+ bs, c, h, w = x.shape
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
+ pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
+ mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w]
+ memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask)
+ target = torch.zeros_like(query_embed)
+ # out_dec: [num_layers, num_query, bs, dim]
+ out_dec = self.decoder(
+ query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask
+ )
+ out_dec = out_dec.transpose(1, 2)
+ memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
+ return out_dec, memory
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DeformableDetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ coder_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+
+ def __init__(self, *args, return_intermediate=False, **kwargs):
+
+ super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+
+ def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+ reference_points (Tensor): The reference
+ points of offset. has shape
+ (bs, num_query, 4) when as_two_stage,
+ otherwise has shape ((bs, num_query, 2).
+ valid_ratios (Tensor): The radios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2)
+ reg_branch: (obj:`nn.ModuleList`): Used for
+ refining the regression results. Only would
+ be passed when with_box_refine is True,
+ otherwise would be passed a `None`.
+
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = (
+ reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ )
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
+ output = layer(output, *args, reference_points=reference_points_input, **kwargs)
+ output = output.permute(1, 0, 2)
+
+ if reg_branches is not None:
+ tmp = reg_branches[lid](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ output = output.permute(1, 0, 2)
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+ return output, reference_points
+
+
+@TRANSFORMER.register_module()
+class DeformableDetrTransformer(Transformer):
+ """Implements the DeformableDETR transformer.
+
+ Args:
+ as_two_stage (bool): Generate query from encoder features.
+ Default: False.
+ num_feature_levels (int): Number of feature maps from FPN:
+ Default: 4.
+ two_stage_num_proposals (int): Number of proposals when set
+ `as_two_stage` as True. Default: 300.
+ """
+
+ def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs):
+ super(DeformableDetrTransformer, self).__init__(**kwargs)
+ self.as_two_stage = as_two_stage
+ self.num_feature_levels = num_feature_levels
+ self.two_stage_num_proposals = two_stage_num_proposals
+ self.embed_dims = self.encoder.embed_dims
+ self.init_layers()
+
+ def init_layers(self):
+ """Initialize layers of the DeformableDetrTransformer."""
+ self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ if self.as_two_stage:
+ self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
+ self.enc_output_norm = nn.LayerNorm(self.embed_dims)
+ self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
+ self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+ else:
+ self.reference_points = nn.Linear(self.embed_dims, 2)
+
+ def init_weights(self):
+ """Initialize the transformer weights."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ if not self.as_two_stage:
+ xavier_init(self.reference_points, distribution="uniform", bias=0.0)
+ normal_(self.level_embeds)
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+ """Generate proposals from encoded memory.
+
+ Args:
+ memory (Tensor) : The output of encoder,
+ has shape (bs, num_key, embed_dim). num_key is
+ equal the number of points on feature map from
+ all level.
+ memory_padding_mask (Tensor): Padding mask for memory.
+ has shape (bs, num_key).
+ spatial_shapes (Tensor): The shape of all feature maps.
+ has shape (num_level, 2).
+
+ Returns:
+ tuple: A tuple of feature map and bbox prediction.
+
+ - output_memory (Tensor): The input of decoder, \
+ has shape (bs, num_key, embed_dim). num_key is \
+ equal the number of points on feature map from \
+ all levels.
+ - output_proposals (Tensor): The normalized proposal \
+ after a inverse sigmoid, has shape \
+ (bs, num_keys, 4).
+ """
+
+ N, S, C = memory.shape
+ proposals = []
+ _cur = 0
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
+ torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
+ proposals.append(proposal)
+ _cur += H * W
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ """Get the reference points used in decoder.
+
+ Args:
+ spatial_shapes (Tensor): The shape of all
+ feature maps, has shape (num_level, 2).
+ valid_ratios (Tensor): The radios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2)
+ device (obj:`device`): The device where
+ reference_points should be.
+
+ Returns:
+ Tensor: reference points used in decoder, has \
+ shape (bs, num_keys, num_levels, 2).
+ """
+ reference_points_list = []
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def get_valid_ratio(self, mask):
+ """Get the valid radios of feature maps of all level."""
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
+ """Get the position embedding of proposal."""
+ scale = 2 * math.pi
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
+ return pos
+
+ def forward(
+ self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs
+ ):
+ """Forward function for `Transformer`.
+
+ Args:
+ mlvl_feats (list(Tensor)): Input queries from
+ different level. Each element has shape
+ [bs, embed_dims, h, w].
+ mlvl_masks (list(Tensor)): The key_padding_mask from
+ different level used for encoder and decoder,
+ each element has shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder,
+ with shape [num_query, c].
+ mlvl_pos_embeds (list(Tensor)): The positional encoding
+ of feats from different level, has the shape
+ [bs, embed_dims, h, w].
+ reg_branches (obj:`nn.ModuleList`): Regression heads for
+ feature maps from each decoder layer. Only would
+ be passed when
+ `with_box_refine` is True. Default to None.
+ cls_branches (obj:`nn.ModuleList`): Classification heads
+ for feature maps from each decoder layer. Only would
+ be passed when `as_two_stage`
+ is True. Default to None.
+
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - inter_states: Outputs from decoder. If
+ return_intermediate_dec is True output has shape \
+ (num_dec_layers, bs, num_query, embed_dims), else has \
+ shape (1, bs, num_query, embed_dims).
+ - init_reference_out: The initial value of reference \
+ points, has shape (bs, num_queries, 4).
+ - inter_references_out: The internal value of reference \
+ points in decoder, has shape \
+ (num_dec_layers, bs,num_query, embed_dims)
+ - enc_outputs_class: The classification score of \
+ proposals generated from \
+ encoder's feature maps, has shape \
+ (batch, h*w, num_classes). \
+ Only would be returned when `as_two_stage` is True, \
+ otherwise None.
+ - enc_outputs_coord_unact: The regression results \
+ generated from encoder's feature maps., has shape \
+ (batch, h*w, 4). Only would \
+ be returned when `as_two_stage` is True, \
+ otherwise None.
+ """
+ assert self.as_two_stage or query_embed is not None
+
+ feat_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
+ bs, c, h, w = feat.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ feat = feat.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ feat_flatten.append(feat)
+ mask_flatten.append(mask)
+ feat_flatten = torch.cat(feat_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device)
+
+ feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
+ lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
+ memory = self.encoder(
+ query=feat_flatten,
+ key=None,
+ value=None,
+ query_pos=lvl_pos_embed_flatten,
+ query_key_padding_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ **kwargs,
+ )
+
+ memory = memory.permute(1, 0, 2)
+ bs, _, c = memory.shape
+ if self.as_two_stage:
+ output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
+ enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory)
+ enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals
+
+ topk = self.two_stage_num_proposals
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
+ query_pos, query = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_pos, query = torch.split(query_embed, c, dim=1)
+ query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
+ query = query.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_pos).sigmoid()
+ init_reference_out = reference_points
+
+ # decoder
+ query = query.permute(1, 0, 2)
+ memory = memory.permute(1, 0, 2)
+ query_pos = query_pos.permute(1, 0, 2)
+ inter_states, inter_references = self.decoder(
+ query=query,
+ key=None,
+ value=memory,
+ query_pos=query_pos,
+ key_padding_mask=mask_flatten,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=reg_branches,
+ **kwargs,
+ )
+
+ inter_references_out = inter_references
+ if self.as_two_stage:
+ return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
+ return inter_states, init_reference_out, inter_references_out, None, None
+
+
+@TRANSFORMER.register_module()
+class DynamicConv(BaseModule):
+ """Implements Dynamic Convolution.
+
+ This module generate parameters for each sample and
+ use bmm to implement 1*1 convolution. Code is modified
+ from the `official github repo `_ .
+
+ Args:
+ in_channels (int): The input feature channel.
+ Defaults to 256.
+ feat_channels (int): The inner feature channel.
+ Defaults to 64.
+ out_channels (int, optional): The output feature channel.
+ When not specified, it will be set to `in_channels`
+ by default
+ input_feat_shape (int): The shape of input feature.
+ Defaults to 7.
+ with_proj (bool): Project two-dimentional feature to
+ one-dimentional feature. Default to True.
+ act_cfg (dict): The activation config for DynamicConv.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ input_feat_shape=7,
+ with_proj=True,
+ act_cfg=dict(type="ReLU", inplace=True),
+ norm_cfg=dict(type="LN"),
+ init_cfg=None,
+ ):
+ super(DynamicConv, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.input_feat_shape = input_feat_shape
+ self.with_proj = with_proj
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.in_channels * self.feat_channels
+ self.num_params_out = self.out_channels * self.feat_channels
+ self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out)
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ num_output = self.out_channels * input_feat_shape**2
+ if self.with_proj:
+ self.fc_layer = nn.Linear(num_output, self.out_channels)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, param_feature, input_feature):
+ """Forward function for `DynamicConv`.
+
+ Args:
+ param_feature (Tensor): The feature can be used
+ to generate the parameter, has shape
+ (num_all_proposals, in_channels).
+ input_feature (Tensor): Feature that
+ interact with parameters, has shape
+ (num_all_proposals, in_channels, H, W).
+
+ Returns:
+ Tensor: The output feature has shape
+ (num_all_proposals, out_channels).
+ """
+ input_feature = input_feature.flatten(2).permute(2, 0, 1)
+
+ input_feature = input_feature.permute(1, 0, 2)
+ parameters = self.dynamic_layer(param_feature)
+
+ param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels)
+
+ # input_feature has shape (num_all_proposals, H*W, in_channels)
+ # param_in has shape (num_all_proposals, in_channels, feat_channels)
+ # feature has shape (num_all_proposals, H*W, feat_channels)
+ features = torch.bmm(input_feature, param_in)
+ features = self.norm_in(features)
+ features = self.activation(features)
+
+ # param_out has shape (batch_size, feat_channels, out_channels)
+ features = torch.bmm(features, param_out)
+ features = self.norm_out(features)
+ features = self.activation(features)
+
+ if self.with_proj:
+ features = features.flatten(1)
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49aa8fe612fd4c088e294707c5ee16bd1cb5b5e7
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules
+# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b4fa23712e87d1a2682b57e71ee37fe8524cff
--- /dev/null
+++ b/mapper/models/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import warnings
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+from torch.cuda.amp import custom_fwd
+from torch.nn.init import constant_, xavier_uniform_
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
+ ):
+ output = ms_deform_attn_core_pytorch(
+ value,
+ value_spatial_shapes,
+ # value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ )
+ return output
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
+ return output.transpose(1, 2).contiguous()
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0):
+ """Multi-Scale Deformable Attention Module.
+
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2
+ # which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn(
+ "You'd better set d_model in MSDeformAttn to make "
+ "the dimension of each attention head a power of 2 "
+ "which is more efficient in our CUDA implementation."
+ )
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+ self.ratio = ratio
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, int(d_model * ratio))
+ self.output_proj = nn.Linear(int(d_model * ratio), d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.n_heads, 1, 1, 2)
+ .repeat(1, self.n_levels, self.n_points, 1)
+ )
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ def forward(
+ self,
+ query,
+ reference_points,
+ input_flatten,
+ input_spatial_shapes,
+ input_level_start_index,
+ input_padding_mask=None,
+ ):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ # print(query.shape)
+ # print(reference_points.shape)
+ # print(input_flatten.shape)
+ # print(input_spatial_shapes.shape)
+ # print(input_level_start_index.shape)
+ # print(input_spatial_shapes)
+ # print(input_level_start_index)
+
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+
+ value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
+ )
+ output = MSDeformAttnFunction.apply(
+ value,
+ input_spatial_shapes,
+ input_level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ output = self.output_proj(output)
+ return output
diff --git a/mapper/models/dinov2/eval/setup.py b/mapper/models/dinov2/eval/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..959128c0673cc51036dbf17dcc4ee68a037988fb
--- /dev/null
+++ b/mapper/models/dinov2/eval/setup.py
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.backends.cudnn as cudnn
+
+from dinov2.models import build_model_from_cfg
+from dinov2.utils.config import setup
+import dinov2.utils.utils as dinov2_utils
+
+
+def get_args_parser(
+ description: Optional[str] = None,
+ parents: Optional[List[argparse.ArgumentParser]] = None,
+ add_help: bool = True,
+):
+ parser = argparse.ArgumentParser(
+ description=description,
+ parents=parents or [],
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--config-file",
+ type=str,
+ help="Model configuration file",
+ )
+ parser.add_argument(
+ "--pretrained-weights",
+ type=str,
+ help="Pretrained model weights",
+ )
+ parser.add_argument(
+ "--output-dir",
+ default="",
+ type=str,
+ help="Output directory to write results and logs",
+ )
+ parser.add_argument(
+ "--opts",
+ help="Extra configuration options",
+ default=[],
+ nargs="+",
+ )
+ return parser
+
+
+def get_autocast_dtype(config):
+ teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
+ if teacher_dtype_str == "fp16":
+ return torch.half
+ elif teacher_dtype_str == "bf16":
+ return torch.bfloat16
+ else:
+ return torch.float
+
+
+def build_model_for_eval(config, pretrained_weights):
+ model, _ = build_model_from_cfg(config, only_teacher=True)
+ dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
+ model.eval()
+ model.cuda()
+ return model
+
+
+def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
+ cudnn.benchmark = True
+ config = setup(args)
+ model = build_model_for_eval(config, args.pretrained_weights)
+ autocast_dtype = get_autocast_dtype(config)
+ return model, autocast_dtype
diff --git a/mapper/models/dinov2/eval/utils.py b/mapper/models/dinov2/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50576b1940587ee64b7a422e2e96b475d60fd39
--- /dev/null
+++ b/mapper/models/dinov2/eval/utils.py
@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict, Optional
+
+import torch
+from torch import nn
+from torchmetrics import MetricCollection
+
+from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
+import dinov2.distributed as distributed
+from dinov2.logging import MetricLogger
+
+
+logger = logging.getLogger("dinov2")
+
+
+class ModelWithNormalize(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+
+ def forward(self, samples):
+ return nn.functional.normalize(self.model(samples), dim=1, p=2)
+
+
+class ModelWithIntermediateLayers(nn.Module):
+ def __init__(self, feature_model, n_last_blocks, autocast_ctx):
+ super().__init__()
+ self.feature_model = feature_model
+ self.feature_model.eval()
+ self.n_last_blocks = n_last_blocks
+ self.autocast_ctx = autocast_ctx
+
+ def forward(self, images):
+ with torch.inference_mode():
+ with self.autocast_ctx():
+ features = self.feature_model.get_intermediate_layers(
+ images, self.n_last_blocks, return_class_token=True
+ )
+ return features
+
+
+@torch.inference_mode()
+def evaluate(
+ model: nn.Module,
+ data_loader,
+ postprocessors: Dict[str, nn.Module],
+ metrics: Dict[str, MetricCollection],
+ device: torch.device,
+ criterion: Optional[nn.Module] = None,
+):
+ model.eval()
+ if criterion is not None:
+ criterion.eval()
+
+ for metric in metrics.values():
+ metric = metric.to(device)
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Test:"
+
+ for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
+ outputs = model(samples.to(device))
+ targets = targets.to(device)
+
+ if criterion is not None:
+ loss = criterion(outputs, targets)
+ metric_logger.update(loss=loss.item())
+
+ for k, metric in metrics.items():
+ metric_inputs = postprocessors[k](outputs, targets)
+ metric.update(**metric_inputs)
+
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged stats: {metric_logger}")
+
+ stats = {k: metric.compute() for k, metric in metrics.items()}
+ metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+ return metric_logger_stats, stats
+
+
+def all_gather_and_flatten(tensor_rank):
+ tensor_all_ranks = torch.empty(
+ distributed.get_global_size(),
+ *tensor_rank.shape,
+ dtype=tensor_rank.dtype,
+ device=tensor_rank.device,
+ )
+ tensor_list = list(tensor_all_ranks.unbind(0))
+ torch.distributed.all_gather(tensor_list, tensor_rank.contiguous())
+ return tensor_all_ranks.flatten(end_dim=1)
+
+
+def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
+ dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
+ sample_count = len(dataset_with_enumerated_targets)
+ data_loader = make_data_loader(
+ dataset=dataset_with_enumerated_targets,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler_type=SamplerType.DISTRIBUTED,
+ drop_last=False,
+ shuffle=False,
+ )
+ return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu)
+
+
+@torch.inference_mode()
+def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
+ gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
+ metric_logger = MetricLogger(delimiter=" ")
+ features, all_labels = None, None
+ for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
+ samples = samples.cuda(non_blocking=True)
+ labels_rank = labels_rank.cuda(non_blocking=True)
+ index = index.cuda(non_blocking=True)
+ features_rank = model(samples).float()
+
+ # init storage feature matrix
+ if features is None:
+ features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
+ labels_shape = list(labels_rank.shape)
+ labels_shape[0] = sample_count
+ all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
+ logger.info(f"Storing features into tensor of shape {features.shape}")
+
+ # share indexes, features and labels between processes
+ index_all = all_gather_and_flatten(index).to(gather_device)
+ features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
+ labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)
+
+ # update storage feature matrix
+ if len(index_all) > 0:
+ features.index_copy_(0, index_all, features_all_ranks)
+ all_labels.index_copy_(0, index_all, labels_all_ranks)
+
+ logger.info(f"Features shape: {tuple(features.shape)}")
+ logger.info(f"Labels shape: {tuple(all_labels.shape)}")
+
+ assert torch.all(all_labels > -1)
+
+ return features, all_labels
diff --git a/mapper/models/dinov2/fsdp/__init__.py b/mapper/models/dinov2/fsdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed454480e0b76e761d657cc40fd097bd339d15a2
--- /dev/null
+++ b/mapper/models/dinov2/fsdp/__init__.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Any
+
+import torch
+import dinov2.distributed as distributed
+from functools import partial
+from fvcore.common.checkpoint import Checkpointer
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp import StateDictType
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.distributed.fsdp._runtime_utils import _reshard
+
+
+def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
+ sharding_strategy_dict = {
+ "NO_SHARD": ShardingStrategy.NO_SHARD,
+ "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
+ "FULL_SHARD": ShardingStrategy.FULL_SHARD,
+ }
+
+ dtype_dict = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ }
+
+ mixed_precision_config = MixedPrecision(
+ param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
+ reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
+ buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
+ )
+
+ sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]
+
+ local_rank = distributed.get_local_rank()
+
+ fsdp_wrapper = partial(
+ FSDP,
+ sharding_strategy=sharding_strategy_config,
+ mixed_precision=mixed_precision_config,
+ device_id=local_rank,
+ sync_module_states=True,
+ use_orig_params=True,
+ auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
+ )
+ return fsdp_wrapper
+
+
+def is_fsdp(x):
+ return isinstance(x, FSDP)
+
+
+def is_sharded_fsdp(x):
+ return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD
+
+
+def free_if_fsdp(x):
+ if is_sharded_fsdp(x):
+ handles = x._handles
+ true_list = [True for h in handles]
+ _reshard(x, handles, true_list)
+
+
+def get_fsdp_modules(x):
+ return FSDP.fsdp_modules(x)
+
+
+def reshard_fsdp_model(x):
+ for m in get_fsdp_modules(x):
+ free_if_fsdp(m)
+
+
+def rankstr():
+ return f"rank_{distributed.get_global_rank()}"
+
+
+class FSDPCheckpointer(Checkpointer):
+ def save(self, name: str, **kwargs: Any) -> None:
+ """
+ Dump model and checkpointables to a file.
+
+ Args:
+ name (str): name of the file.
+ kwargs (dict): extra arbitrary data to save.
+ """
+ if not self.save_dir or not self.save_to_disk:
+ return
+
+ data = {}
+ with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
+ data["model"] = self.model.state_dict()
+
+ # data["model"] = self.model.state_dict()
+ for key, obj in self.checkpointables.items():
+ data[key] = obj.state_dict()
+ data.update(kwargs)
+
+ basename = f"{name}.{rankstr()}.pth"
+ save_file = os.path.join(self.save_dir, basename)
+ assert os.path.basename(save_file) == basename, basename
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ with self.path_manager.open(save_file, "wb") as f:
+ torch.save(data, f)
+ self.tag_last_checkpoint(basename)
+
+ def load(self, *args, **kwargs):
+ with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
+ return super().load(*args, **kwargs)
+
+ def has_checkpoint(self) -> bool:
+ """
+ Returns:
+ bool: whether a checkpoint exists in the target directory.
+ """
+ save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
+ return self.path_manager.exists(save_file)
+
+ def get_checkpoint_file(self) -> str:
+ """
+ Returns:
+ str: The latest checkpoint file in target directory.
+ """
+ save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
+ try:
+ with self.path_manager.open(save_file, "r") as f:
+ last_saved = f.read().strip()
+ except IOError:
+ # if file doesn't exist, maybe because it has just been
+ # deleted by a separate process
+ return ""
+ # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
+ # `Union[bytes, str]`.
+ return os.path.join(self.save_dir, last_saved)
+
+ def tag_last_checkpoint(self, last_filename_basename: str) -> None:
+ """
+ Tag the last checkpoint.
+
+ Args:
+ last_filename_basename (str): the basename of the last filename.
+ """
+ if distributed.is_enabled():
+ torch.distributed.barrier()
+ save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
+ with self.path_manager.open(save_file, "w") as f:
+ f.write(last_filename_basename) # pyre-ignore
+
+
+ShardedGradScaler = ShardedGradScaler
diff --git a/mapper/models/dinov2/hub/__init__.py b/mapper/models/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51eee37fc44057f0438adb43b488f4bef8bd86fc
--- /dev/null
+++ b/mapper/models/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/mapper/models/dinov2/hub/backbones.py b/mapper/models/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f666fd3de8dbe4717b556af80de558f824b5517
--- /dev/null
+++ b/mapper/models/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
\ No newline at end of file
diff --git a/mapper/models/dinov2/hub/classifiers.py b/mapper/models/dinov2/hub/classifiers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebbe2337ec0119bd2cb091656544e45b3be6d9bf
--- /dev/null
+++ b/mapper/models/dinov2/hub/classifiers.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+ *,
+ arch_name: str = "vit_large",
+ patch_size: int = 14,
+ embed_dim: int = 1024,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+ if pretrained:
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ layers_str = str(layers) if layers == 4 else ""
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ linear_head.load_state_dict(state_dict, strict=True)
+
+ return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+ def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+ super().__init__()
+ self.backbone = backbone
+ self.linear_head = linear_head
+ self.layers = layers
+
+ def forward(self, x):
+ if self.layers == 1:
+ x = self.backbone.forward_features(x)
+ cls_token = x["x_norm_clstoken"]
+ patch_tokens = x["x_norm_patchtokens"]
+ # fmt: off
+ linear_input = torch.cat([
+ cls_token,
+ patch_tokens.mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ elif self.layers == 4:
+ x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+ # fmt: off
+ linear_input = torch.cat([
+ x[0][1],
+ x[1][1],
+ x[2][1],
+ x[3][1],
+ x[3][0].mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ else:
+ assert False, f"Unsupported number of layers: {self.layers}"
+ return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+ *,
+ arch_name: str = "vit_large",
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ **kwargs,
+):
+ backbone = _make_dinov2_model(
+ arch_name=arch_name,
+ pretrained=pretrained,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ **kwargs,
+ )
+
+ embed_dim = backbone.embed_dim
+ patch_size = backbone.patch_size
+ linear_head = _make_dinov2_linear_classification_head(
+ arch_name=arch_name,
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=num_register_tokens,
+ )
+
+ return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
\ No newline at end of file
diff --git a/mapper/models/dinov2/hub/utils.py b/mapper/models/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4afb84c777d8e7ce51f8ff262dbea2c7fa06726
--- /dev/null
+++ b/mapper/models/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
\ No newline at end of file
diff --git a/mapper/models/dinov2/layers/__init__.py b/mapper/models/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e
--- /dev/null
+++ b/mapper/models/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/mapper/models/dinov2/layers/attention.py b/mapper/models/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777
--- /dev/null
+++ b/mapper/models/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+ else:
+ warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/mapper/models/dinov2/layers/block.py b/mapper/models/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..930787b262faac4f2264797496faff75ac56b7cc
--- /dev/null
+++ b/mapper/models/dinov2/layers/block.py
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Block)")
+ else:
+ warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/mapper/models/dinov2/layers/dino_head.py b/mapper/models/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/mapper/models/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/mapper/models/dinov2/layers/drop_path.py b/mapper/models/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/mapper/models/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/mapper/models/dinov2/layers/layer_scale.py b/mapper/models/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/mapper/models/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/mapper/models/dinov2/layers/mlp.py b/mapper/models/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/mapper/models/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/mapper/models/dinov2/layers/patch_embed.py b/mapper/models/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/mapper/models/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/mapper/models/dinov2/layers/swiglu_ffn.py b/mapper/models/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74
--- /dev/null
+++ b/mapper/models/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/mapper/models/dinov2/logging/__init__.py b/mapper/models/dinov2/logging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a7f02204316d4d1ef38bf6080dae3d66241c25
--- /dev/null
+++ b/mapper/models/dinov2/logging/__init__.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import functools
+import logging
+import os
+import sys
+from typing import Optional
+
+import dinov2.distributed as distributed
+from .helpers import MetricLogger, SmoothedValue
+
+
+# So that calling _configure_logger multiple times won't add many handlers
+@functools.lru_cache()
+def _configure_logger(
+ name: Optional[str] = None,
+ *,
+ level: int = logging.DEBUG,
+ output: Optional[str] = None,
+):
+ """
+ Configure a logger.
+
+ Adapted from Detectron2.
+
+ Args:
+ name: The name of the logger to configure.
+ level: The logging level to use.
+ output: A file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+
+ Returns:
+ The configured logger.
+ """
+
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+ logger.propagate = False
+
+ # Loosely match Google glog format:
+ # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
+ # but use a shorter timestamp and include the logger name:
+ # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg
+ fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
+ fmt_message = "%(message)s"
+ fmt = fmt_prefix + fmt_message
+ datefmt = "%Y%m%d %H:%M:%S"
+ formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
+
+ # stdout logging for main worker only
+ if distributed.is_main_process():
+ handler = logging.StreamHandler(stream=sys.stdout)
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ # file logging for all workers
+ if output:
+ if os.path.splitext(output)[-1] in (".txt", ".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "logs", "log.txt")
+
+ if not distributed.is_main_process():
+ global_rank = distributed.get_global_rank()
+ filename = filename + ".rank{}".format(global_rank)
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ handler = logging.StreamHandler(open(filename, "a"))
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ return logger
+
+
+def setup_logging(
+ output: Optional[str] = None,
+ *,
+ name: Optional[str] = None,
+ level: int = logging.DEBUG,
+ capture_warnings: bool = True,
+) -> None:
+ """
+ Setup logging.
+
+ Args:
+ output: A file name or a directory to save log files. If None, log
+ files will not be saved. If output ends with ".txt" or ".log", it
+ is assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name: The name of the logger to configure, by default the root logger.
+ level: The logging level to use.
+ capture_warnings: Whether warnings should be captured as logs.
+ """
+ logging.captureWarnings(capture_warnings)
+ _configure_logger(name, level=level, output=output)
diff --git a/mapper/models/dinov2/logging/helpers.py b/mapper/models/dinov2/logging/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6e70bb15505cbbc4c4732b069ee919bf921a74f
--- /dev/null
+++ b/mapper/models/dinov2/logging/helpers.py
@@ -0,0 +1,194 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict, deque
+import datetime
+import json
+import logging
+import time
+
+import torch
+
+import dinov2.distributed as distributed
+
+
+logger = logging.getLogger("dinov2")
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t", output_file=None):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+ self.output_file = output_file
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def dump_in_output_file(self, iteration, iter_time, data_time):
+ if self.output_file is None or not distributed.is_main_process():
+ return
+ dict_to_dump = dict(
+ iteration=iteration,
+ iter_time=iter_time,
+ data_time=data_time,
+ )
+ dict_to_dump.update({k: v.median for k, v in self.meters.items()})
+ with open(self.output_file, "a") as f:
+ f.write(json.dumps(dict_to_dump) + "\n")
+ pass
+
+ def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
+ i = start_iteration
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.6f}")
+ data_time = SmoothedValue(fmt="{avg:.6f}")
+
+ if n_iterations is None:
+ n_iterations = len(iterable)
+
+ space_fmt = ":" + str(len(str(n_iterations))) + "d"
+
+ log_list = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_list += ["max mem: {memory:.0f}"]
+
+ log_msg = self.delimiter.join(log_list)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == n_iterations - 1:
+ self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
+ eta_seconds = iter_time.global_avg * (n_iterations - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ logger.info(
+ log_msg.format(
+ i,
+ n_iterations,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ logger.info(
+ log_msg.format(
+ i,
+ n_iterations,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ if i >= n_iterations:
+ break
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations))
+
+
+class SmoothedValue:
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, num=1):
+ self.deque.append(value)
+ self.count += num
+ self.total += value * num
+
+ def synchronize_between_processes(self):
+ """
+ Distributed synchronization of the metric
+ Warning: does not synchronize the deque!
+ """
+ if not distributed.is_enabled():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
diff --git a/mapper/models/dinov2/loss/__init__.py b/mapper/models/dinov2/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6b0115b74edbd74b324c9056a57fade363c58fd
--- /dev/null
+++ b/mapper/models/dinov2/loss/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_clstoken_loss import DINOLoss
+from .ibot_patch_loss import iBOTPatchLoss
+from .koleo_loss import KoLeoLoss
diff --git a/mapper/models/dinov2/loss/dino_clstoken_loss.py b/mapper/models/dinov2/loss/dino_clstoken_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c31808e36e6c38ee6dae13ba0443bf1946242117
--- /dev/null
+++ b/mapper/models/dinov2/loss/dino_clstoken_loss.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn
+
+
+class DINOLoss(nn.Module):
+ def __init__(
+ self,
+ out_dim,
+ student_temp=0.1,
+ center_momentum=0.9,
+ ):
+ super().__init__()
+ self.student_temp = student_temp
+ self.center_momentum = center_momentum
+ self.register_buffer("center", torch.zeros(1, out_dim))
+ self.updated = True
+ self.reduce_handle = None
+ self.len_teacher_output = None
+ self.async_batch_center = None
+
+ @torch.no_grad()
+ def softmax_center_teacher(self, teacher_output, teacher_temp):
+ self.apply_center_update()
+ # teacher centering and sharpening
+ return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)
+
+ @torch.no_grad()
+ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
+ teacher_output = teacher_output.float()
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
+ B = Q.shape[1] * world_size # number of samples to assign
+ K = Q.shape[0] # how many prototypes
+
+ # make the matrix sums to 1
+ sum_Q = torch.sum(Q)
+ if dist.is_initialized():
+ dist.all_reduce(sum_Q)
+ Q /= sum_Q
+
+ for it in range(n_iterations):
+ # normalize each row: total weight per prototype must be 1/K
+ sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+ if dist.is_initialized():
+ dist.all_reduce(sum_of_rows)
+ Q /= sum_of_rows
+ Q /= K
+
+ # normalize each column: total weight per sample must be 1/B
+ Q /= torch.sum(Q, dim=0, keepdim=True)
+ Q /= B
+
+ Q *= B # the columns must sum to 1 so that Q is an assignment
+ return Q.t()
+
+ def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
+ """
+ Cross-entropy between softmax outputs of the teacher and student networks.
+ """
+ # TODO: Use cross_entropy_distribution here
+ total_loss = 0
+ for s in student_output_list:
+ lsm = F.log_softmax(s / self.student_temp, dim=-1)
+ for t in teacher_out_softmaxed_centered_list:
+ loss = torch.sum(t * lsm, dim=-1)
+ total_loss -= loss.mean()
+ return total_loss
+
+ @torch.no_grad()
+ def update_center(self, teacher_output):
+ self.reduce_center_update(teacher_output)
+
+ @torch.no_grad()
+ def reduce_center_update(self, teacher_output):
+ self.updated = False
+ self.len_teacher_output = len(teacher_output)
+ self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
+ if dist.is_initialized():
+ self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
+
+ @torch.no_grad()
+ def apply_center_update(self):
+ if self.updated is False:
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+ if self.reduce_handle is not None:
+ self.reduce_handle.wait()
+ _t = self.async_batch_center / (self.len_teacher_output * world_size)
+
+ self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
+
+ self.updated = True
diff --git a/mapper/models/dinov2/loss/ibot_patch_loss.py b/mapper/models/dinov2/loss/ibot_patch_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6732cda0c311c69f193669ebc950fc8665871442
--- /dev/null
+++ b/mapper/models/dinov2/loss/ibot_patch_loss.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn
+
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import cross_entropy
+
+ def lossfunc(t, s, temp):
+ s = s.float()
+ t = t.float()
+ if s.ndim == 2:
+ return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
+ elif s.ndim == 3:
+ return -cross_entropy(s, t, temp, bw_inplace=True)
+
+except ImportError:
+
+ def lossfunc(t, s, temp):
+ return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
+
+
+class iBOTPatchLoss(nn.Module):
+ def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
+ super().__init__()
+ self.student_temp = student_temp
+ self.center_momentum = center_momentum
+ self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
+ self.updated = True
+ self.reduce_handle = None
+ self.len_teacher_patch_tokens = None
+ self.async_batch_center = None
+
+ @torch.no_grad()
+ def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
+ self.apply_center_update()
+ # teacher centering and sharpening
+ #
+ # WARNING:
+ # as self.center is a float32, everything gets casted to float32 afterwards
+ #
+ # teacher_patch_tokens = teacher_patch_tokens.float()
+ # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1)
+
+ return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)
+
+ # this is experimental, keep everything in float16 and let's see what happens:
+ # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1)
+
+ @torch.no_grad()
+ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
+ teacher_output = teacher_output.float()
+ # world_size = dist.get_world_size() if dist.is_initialized() else 1
+ Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
+ # B = Q.shape[1] * world_size # number of samples to assign
+ B = n_masked_patches_tensor
+ dist.all_reduce(B)
+ K = Q.shape[0] # how many prototypes
+
+ # make the matrix sums to 1
+ sum_Q = torch.sum(Q)
+ if dist.is_initialized():
+ dist.all_reduce(sum_Q)
+ Q /= sum_Q
+
+ for it in range(n_iterations):
+ # normalize each row: total weight per prototype must be 1/K
+ sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+ if dist.is_initialized():
+ dist.all_reduce(sum_of_rows)
+ Q /= sum_of_rows
+ Q /= K
+
+ # normalize each column: total weight per sample must be 1/B
+ Q /= torch.sum(Q, dim=0, keepdim=True)
+ Q /= B
+
+ Q *= B # the columns must sum to 1 so that Q is an assignment
+ return Q.t()
+
+ def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
+ """
+ Cross-entropy between softmax outputs of the teacher and student networks.
+ student_patch_tokens: (B, N, D) tensor
+ teacher_patch_tokens: (B, N, D) tensor
+ student_masks_flat: (B, N) tensor
+ """
+ t = teacher_patch_tokens
+ s = student_patch_tokens
+ loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
+ loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
+ return -loss.mean()
+
+ def forward_masked(
+ self,
+ student_patch_tokens_masked,
+ teacher_patch_tokens_masked,
+ student_masks_flat,
+ n_masked_patches=None,
+ masks_weight=None,
+ ):
+ t = teacher_patch_tokens_masked
+ s = student_patch_tokens_masked
+ # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
+ loss = lossfunc(t, s, self.student_temp)
+ if masks_weight is None:
+ masks_weight = (
+ (1 / student_masks_flat.sum(-1).clamp(min=1.0))
+ .unsqueeze(-1)
+ .expand_as(student_masks_flat)[student_masks_flat]
+ )
+ if n_masked_patches is not None:
+ loss = loss[:n_masked_patches]
+ loss = loss * masks_weight
+ return -loss.sum() / student_masks_flat.shape[0]
+
+ @torch.no_grad()
+ def update_center(self, teacher_patch_tokens):
+ self.reduce_center_update(teacher_patch_tokens)
+
+ @torch.no_grad()
+ def reduce_center_update(self, teacher_patch_tokens):
+ self.updated = False
+ self.len_teacher_patch_tokens = len(teacher_patch_tokens)
+ self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
+ if dist.is_initialized():
+ self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
+
+ @torch.no_grad()
+ def apply_center_update(self):
+ if self.updated is False:
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+ if self.reduce_handle is not None:
+ self.reduce_handle.wait()
+ _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)
+
+ self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
+
+ self.updated = True
diff --git a/mapper/models/dinov2/loss/koleo_loss.py b/mapper/models/dinov2/loss/koleo_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5cbcd91e0fc0b857f477b0910f957f02a6c4335
--- /dev/null
+++ b/mapper/models/dinov2/loss/koleo_loss.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# import torch.distributed as dist
+
+
+logger = logging.getLogger("dinov2")
+
+
+class KoLeoLoss(nn.Module):
+ """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""
+
+ def __init__(self):
+ super().__init__()
+ self.pdist = nn.PairwiseDistance(2, eps=1e-8)
+
+ def pairwise_NNs_inner(self, x):
+ """
+ Pairwise nearest neighbors for L2-normalized vectors.
+ Uses Torch rather than Faiss to remain on GPU.
+ """
+ # parwise dot products (= inverse distance)
+ dots = torch.mm(x, x.t())
+ n = x.shape[0]
+ dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1
+ # max inner prod -> min distance
+ _, I = torch.max(dots, dim=1) # noqa: E741
+ return I
+
+ def forward(self, student_output, eps=1e-8):
+ """
+ Args:
+ student_output (BxD): backbone output of student
+ """
+ with torch.cuda.amp.autocast(enabled=False):
+ student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
+ I = self.pairwise_NNs_inner(student_output) # noqa: E741
+ distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B
+ loss = -torch.log(distances + eps).mean()
+ return loss
diff --git a/mapper/models/dinov2/models/__init__.py b/mapper/models/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cc5d4a54f4fdf9b7a343bf2950605fd38ff916c
--- /dev/null
+++ b/mapper/models/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
\ No newline at end of file
diff --git a/mapper/models/dinov2/models/vision_transformer.py b/mapper/models/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b324dc261582bc1e811322cbfab57e0666e9d41a
--- /dev/null
+++ b/mapper/models/dinov2/models/vision_transformer.py
@@ -0,0 +1,393 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ # print(outputs[0].shape, x.shape, B, w // self.patch_size, h // self.patch_size, "\n"*10)
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
\ No newline at end of file
diff --git a/mapper/models/dinov2/run/__init__.py b/mapper/models/dinov2/run/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/run/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/run/eval/knn.py b/mapper/models/dinov2/run/eval/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11918445cdfe415fe58ac8b3ad0bf29702e3457
--- /dev/null
+++ b/mapper/models/dinov2/run/eval/knn.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+
+from dinov2.eval.knn import get_args_parser as get_knn_args_parser
+from dinov2.logging import setup_logging
+from dinov2.run.submit import get_args_parser, submit_jobs
+
+
+logger = logging.getLogger("dinov2")
+
+
+class Evaluator:
+ def __init__(self, args):
+ self.args = args
+
+ def __call__(self):
+ from dinov2.eval.knn import main as knn_main
+
+ self._setup_args()
+ knn_main(self.args)
+
+ def checkpoint(self):
+ import submitit
+
+ logger.info(f"Requeuing {self.args}")
+ empty = type(self)(self.args)
+ return submitit.helpers.DelayedSubmission(empty)
+
+ def _setup_args(self):
+ import submitit
+
+ job_env = submitit.JobEnvironment()
+ self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
+ logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+ logger.info(f"Args: {self.args}")
+
+
+def main():
+ description = "Submitit launcher for DINOv2 k-NN evaluation"
+ knn_args_parser = get_knn_args_parser(add_help=False)
+ parents = [knn_args_parser]
+ args_parser = get_args_parser(description=description, parents=parents)
+ args = args_parser.parse_args()
+
+ setup_logging()
+
+ assert os.path.exists(args.config_file), "Configuration file does not exist!"
+ submit_jobs(Evaluator, args, name="dinov2:knn")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/mapper/models/dinov2/run/eval/linear.py b/mapper/models/dinov2/run/eval/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1dc3293e88512a5cf885ab775dc08e01aed6724
--- /dev/null
+++ b/mapper/models/dinov2/run/eval/linear.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+
+from dinov2.eval.linear import get_args_parser as get_linear_args_parser
+from dinov2.logging import setup_logging
+from dinov2.run.submit import get_args_parser, submit_jobs
+
+
+logger = logging.getLogger("dinov2")
+
+
+class Evaluator:
+ def __init__(self, args):
+ self.args = args
+
+ def __call__(self):
+ from dinov2.eval.linear import main as linear_main
+
+ self._setup_args()
+ linear_main(self.args)
+
+ def checkpoint(self):
+ import submitit
+
+ logger.info(f"Requeuing {self.args}")
+ empty = type(self)(self.args)
+ return submitit.helpers.DelayedSubmission(empty)
+
+ def _setup_args(self):
+ import submitit
+
+ job_env = submitit.JobEnvironment()
+ self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
+ logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+ logger.info(f"Args: {self.args}")
+
+
+def main():
+ description = "Submitit launcher for DINOv2 linear evaluation"
+ linear_args_parser = get_linear_args_parser(add_help=False)
+ parents = [linear_args_parser]
+ args_parser = get_args_parser(description=description, parents=parents)
+ args = args_parser.parse_args()
+
+ setup_logging()
+
+ assert os.path.exists(args.config_file), "Configuration file does not exist!"
+ submit_jobs(Evaluator, args, name="dinov2:linear")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/mapper/models/dinov2/run/eval/log_regression.py b/mapper/models/dinov2/run/eval/log_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf02181122de72cfa463ef38494967219df9cf3
--- /dev/null
+++ b/mapper/models/dinov2/run/eval/log_regression.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+
+from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser
+from dinov2.logging import setup_logging
+from dinov2.run.submit import get_args_parser, submit_jobs
+
+
+logger = logging.getLogger("dinov2")
+
+
+class Evaluator:
+ def __init__(self, args):
+ self.args = args
+
+ def __call__(self):
+ from dinov2.eval.log_regression import main as log_regression_main
+
+ self._setup_args()
+ log_regression_main(self.args)
+
+ def checkpoint(self):
+ import submitit
+
+ logger.info(f"Requeuing {self.args}")
+ empty = type(self)(self.args)
+ return submitit.helpers.DelayedSubmission(empty)
+
+ def _setup_args(self):
+ import submitit
+
+ job_env = submitit.JobEnvironment()
+ self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
+ logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+ logger.info(f"Args: {self.args}")
+
+
+def main():
+ description = "Submitit launcher for DINOv2 logistic evaluation"
+ log_regression_args_parser = get_log_regression_args_parser(add_help=False)
+ parents = [log_regression_args_parser]
+ args_parser = get_args_parser(description=description, parents=parents)
+ args = args_parser.parse_args()
+
+ setup_logging()
+
+ assert os.path.exists(args.config_file), "Configuration file does not exist!"
+ submit_jobs(Evaluator, args, name="dinov2:logreg")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/mapper/models/dinov2/run/submit.py b/mapper/models/dinov2/run/submit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1f718e704cf9a48913422404c25a7fcc50e738
--- /dev/null
+++ b/mapper/models/dinov2/run/submit.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import submitit
+
+from dinov2.utils.cluster import (
+ get_slurm_executor_parameters,
+ get_slurm_partition,
+ get_user_checkpoint_path,
+)
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_args_parser(
+ description: Optional[str] = None,
+ parents: Optional[List[argparse.ArgumentParser]] = None,
+ add_help: bool = True,
+) -> argparse.ArgumentParser:
+ parents = parents or []
+ slurm_partition = get_slurm_partition()
+ parser = argparse.ArgumentParser(
+ description=description,
+ parents=parents,
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--ngpus",
+ "--gpus",
+ "--gpus-per-node",
+ default=8,
+ type=int,
+ help="Number of GPUs to request on each node",
+ )
+ parser.add_argument(
+ "--nodes",
+ "--nnodes",
+ default=1,
+ type=int,
+ help="Number of nodes to request",
+ )
+ parser.add_argument(
+ "--timeout",
+ default=2800,
+ type=int,
+ help="Duration of the job",
+ )
+ parser.add_argument(
+ "--partition",
+ default=slurm_partition,
+ type=str,
+ help="Partition where to submit",
+ )
+ parser.add_argument(
+ "--use-volta32",
+ action="store_true",
+ help="Request V100-32GB GPUs",
+ )
+ parser.add_argument(
+ "--comment",
+ default="",
+ type=str,
+ help="Comment to pass to scheduler, e.g. priority message",
+ )
+ parser.add_argument(
+ "--exclude",
+ default="",
+ type=str,
+ help="Nodes to exclude",
+ )
+ return parser
+
+
+def get_shared_folder() -> Path:
+ user_checkpoint_path = get_user_checkpoint_path()
+ if user_checkpoint_path is None:
+ raise RuntimeError("Path to user checkpoint cannot be determined")
+ path = user_checkpoint_path / "experiments"
+ path.mkdir(exist_ok=True)
+ return path
+
+
+def submit_jobs(task_class, args, name: str):
+ if not args.output_dir:
+ args.output_dir = str(get_shared_folder() / "%j")
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
+
+ kwargs = {}
+ if args.use_volta32:
+ kwargs["slurm_constraint"] = "volta32gb"
+ if args.comment:
+ kwargs["slurm_comment"] = args.comment
+ if args.exclude:
+ kwargs["slurm_exclude"] = args.exclude
+
+ executor_params = get_slurm_executor_parameters(
+ nodes=args.nodes,
+ num_gpus_per_node=args.ngpus,
+ timeout_min=args.timeout, # max is 60 * 72
+ slurm_signal_delay_s=120,
+ slurm_partition=args.partition,
+ **kwargs,
+ )
+ executor.update_parameters(name=name, **executor_params)
+
+ task = task_class(args)
+ job = executor.submit(task)
+
+ logger.info(f"Submitted job_id: {job.job_id}")
+ str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id))
+ logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}")
diff --git a/mapper/models/dinov2/run/train/train.py b/mapper/models/dinov2/run/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2366e9bf79765e6abcd70dda6b43f31cb7093eb
--- /dev/null
+++ b/mapper/models/dinov2/run/train/train.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+
+from dinov2.logging import setup_logging
+from dinov2.train import get_args_parser as get_train_args_parser
+from dinov2.run.submit import get_args_parser, submit_jobs
+
+
+logger = logging.getLogger("dinov2")
+
+
+class Trainer(object):
+ def __init__(self, args):
+ self.args = args
+
+ def __call__(self):
+ from dinov2.train import main as train_main
+
+ self._setup_args()
+ train_main(self.args)
+
+ def checkpoint(self):
+ import submitit
+
+ logger.info(f"Requeuing {self.args}")
+ empty = type(self)(self.args)
+ return submitit.helpers.DelayedSubmission(empty)
+
+ def _setup_args(self):
+ import submitit
+
+ job_env = submitit.JobEnvironment()
+ self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
+ logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+ logger.info(f"Args: {self.args}")
+
+
+def main():
+ description = "Submitit launcher for DINOv2 training"
+ train_args_parser = get_train_args_parser(add_help=False)
+ parents = [train_args_parser]
+ args_parser = get_args_parser(description=description, parents=parents)
+ args = args_parser.parse_args()
+
+ setup_logging()
+
+ assert os.path.exists(args.config_file), "Configuration file does not exist!"
+ submit_jobs(Trainer, args, name="dinov2:train")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/mapper/models/dinov2/train/__init__.py b/mapper/models/dinov2/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1752922d04fff0112eb7796be28ff6b68c6073
--- /dev/null
+++ b/mapper/models/dinov2/train/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .train import get_args_parser, main
+from .ssl_meta_arch import SSLMetaArch
diff --git a/mapper/models/dinov2/train/ssl_meta_arch.py b/mapper/models/dinov2/train/ssl_meta_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ccf15e904ebeb6134dfb4f5c99da4fc8d41b8e4
--- /dev/null
+++ b/mapper/models/dinov2/train/ssl_meta_arch.py
@@ -0,0 +1,400 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from functools import partial
+import logging
+
+import torch
+from torch import nn
+
+from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss
+from dinov2.models import build_model_from_cfg
+from dinov2.layers import DINOHead
+from dinov2.utils.utils import has_batchnorms
+from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups
+from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model
+
+from dinov2.models.vision_transformer import BlockChunk
+
+
+try:
+ from xformers.ops import fmha
+except ImportError:
+ raise AssertionError("xFormers is required for training")
+
+
+logger = logging.getLogger("dinov2")
+
+
+class SSLMetaArch(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None
+
+ student_model_dict = dict()
+ teacher_model_dict = dict()
+
+ student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
+ student_model_dict["backbone"] = student_backbone
+ teacher_model_dict["backbone"] = teacher_backbone
+ logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
+
+ if cfg.student.pretrained_weights:
+ chkpt = torch.load(cfg.student.pretrained_weights)
+ logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
+ student_backbone.load_state_dict(chkpt["model"], strict=False)
+
+ self.embed_dim = embed_dim
+ self.dino_out_dim = cfg.dino.head_n_prototypes
+
+ self.do_dino = cfg.dino.loss_weight > 0
+ self.do_koleo = cfg.dino.koleo_loss_weight > 0
+ self.do_ibot = cfg.ibot.loss_weight > 0
+ self.ibot_separate_head = cfg.ibot.separate_head
+
+ logger.info("OPTIONS -- DINO")
+ if self.do_dino:
+ logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}")
+ logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}")
+ logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}")
+ logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}")
+ self.dino_loss_weight = cfg.dino.loss_weight
+ dino_head = partial(
+ DINOHead,
+ in_dim=embed_dim,
+ out_dim=cfg.dino.head_n_prototypes,
+ hidden_dim=cfg.dino.head_hidden_dim,
+ bottleneck_dim=cfg.dino.head_bottleneck_dim,
+ nlayers=cfg.dino.head_nlayers,
+ )
+ self.dino_loss = DINOLoss(self.dino_out_dim)
+ if self.do_koleo:
+ logger.info("OPTIONS -- DINO -- applying KOLEO regularization")
+ self.koleo_loss = KoLeoLoss()
+
+ else:
+ logger.info("OPTIONS -- DINO -- not using DINO")
+
+ if self.do_dino or self.do_ibot:
+ student_model_dict["dino_head"] = dino_head()
+ teacher_model_dict["dino_head"] = dino_head()
+
+ logger.info("OPTIONS -- IBOT")
+ logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
+ logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}")
+ logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}")
+ if self.do_ibot:
+ self.ibot_loss_weight = cfg.ibot.loss_weight
+ assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot"
+ assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot"
+ self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes
+ self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim)
+ if self.ibot_separate_head:
+ logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
+ logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}")
+ logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}")
+ logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}")
+ ibot_head = partial(
+ DINOHead,
+ in_dim=embed_dim,
+ out_dim=cfg.ibot.head_n_prototypes,
+ hidden_dim=cfg.ibot.head_hidden_dim,
+ bottleneck_dim=cfg.ibot.head_bottleneck_dim,
+ nlayers=cfg.ibot.head_nlayers,
+ )
+ student_model_dict["ibot_head"] = ibot_head()
+ teacher_model_dict["ibot_head"] = ibot_head()
+ else:
+ logger.info("OPTIONS -- IBOT -- head shared with DINO")
+
+ self.need_to_synchronize_fsdp_streams = True
+
+ self.student = nn.ModuleDict(student_model_dict)
+ self.teacher = nn.ModuleDict(teacher_model_dict)
+
+ # there is no backpropagation through the teacher, so no need for gradients
+ for p in self.teacher.parameters():
+ p.requires_grad = False
+ logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.")
+
+ def forward(self, inputs):
+ raise NotImplementedError
+
+ def backprop_loss(self, loss):
+ if self.fp16_scaler is not None:
+ self.fp16_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ def forward_backward(self, images, teacher_temp):
+ n_global_crops = 2
+ assert n_global_crops == 2
+ n_local_crops = self.cfg.crops.local_crops_number
+
+ global_crops = images["collated_global_crops"].cuda(non_blocking=True)
+ local_crops = images["collated_local_crops"].cuda(non_blocking=True)
+
+ masks = images["collated_masks"].cuda(non_blocking=True)
+ mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)
+ n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)
+ n_masked_patches = mask_indices_list.shape[0]
+ upperbound = images["upperbound"]
+ masks_weight = images["masks_weight"].cuda(non_blocking=True)
+
+ n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)
+ n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops
+
+ do_dino = self.do_dino
+ do_ibot = self.do_ibot
+
+ # loss scales
+ ibot_loss_scale = 1.0 / n_global_crops
+
+ # teacher output
+ @torch.no_grad()
+ def get_teacher_output():
+ x, n_global_crops_teacher = global_crops, n_global_crops
+ teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True)
+ teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
+ teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher)
+ # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss
+ teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0]))
+ ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
+ _dim = ibot_teacher_patch_tokens.shape[-1]
+ n_cls_tokens = teacher_cls_tokens.shape[0]
+
+ if do_ibot and not self.ibot_separate_head:
+ buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim)
+ buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens)
+ torch.index_select(
+ ibot_teacher_patch_tokens.flatten(0, 1),
+ dim=0,
+ index=mask_indices_list,
+ out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches],
+ )
+ tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher)
+ teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens]
+ masked_teacher_patch_tokens_after_head = tokens_after_head[
+ n_cls_tokens : n_cls_tokens + n_masked_patches
+ ]
+ elif do_ibot and self.ibot_separate_head:
+ buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim)
+ torch.index_select(
+ ibot_teacher_patch_tokens.flatten(0, 1),
+ dim=0,
+ index=mask_indices_list,
+ out=buffer_tensor_teacher[:n_masked_patches],
+ )
+ teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
+ masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[
+ :n_masked_patches
+ ]
+ else:
+ teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
+ masked_teacher_ibot_softmaxed_centered = None
+
+ if self.cfg.train.centering == "centering":
+ teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher(
+ teacher_cls_tokens_after_head, teacher_temp=teacher_temp
+ ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])
+ self.dino_loss.update_center(teacher_cls_tokens_after_head)
+ if do_ibot:
+ masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0)
+ masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher(
+ masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp
+ )
+ masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0)
+ self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches])
+
+ elif self.cfg.train.centering == "sinkhorn_knopp":
+ teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher(
+ teacher_cls_tokens_after_head, teacher_temp=teacher_temp
+ ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])
+
+ if do_ibot:
+ masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher(
+ masked_teacher_patch_tokens_after_head,
+ teacher_temp=teacher_temp,
+ n_masked_patches_tensor=n_masked_patches_tensor,
+ )
+
+ else:
+ raise NotImplementedError
+
+ return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered
+
+ teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
+ reshard_fsdp_model(self.teacher)
+
+ loss_dict = {}
+
+ loss_accumulator = 0 # for backprop
+ student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone(
+ [global_crops, local_crops], masks=[masks, None], is_training=True
+ )
+
+ inputs_for_student_head_list = []
+
+ # 1a: local crops cls tokens
+ student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"]
+ inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0))
+
+ # 1b: global crops cls tokens
+ student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"]
+ inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0))
+
+ # 1c: global crops patch tokens
+ if do_ibot:
+ _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1]
+ ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"]
+ buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim)
+ buffer_tensor_patch_tokens[:n_masked_patches].copy_(
+ torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list)
+ )
+ if not self.ibot_separate_head:
+ inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0))
+ else:
+ student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[
+ :n_masked_patches
+ ]
+
+ # 2: run
+ _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
+ outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))
+
+ # 3a: local crops cls tokens
+ student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)
+
+ # 3b: global crops cls tokens
+ student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)
+
+ # 3c: global crops patch tokens
+ if do_ibot and not self.ibot_separate_head:
+ student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches]
+
+ if n_local_crops > 0:
+ dino_local_crops_loss = self.dino_loss(
+ student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops),
+ teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list,
+ ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)
+
+ # store for display
+ loss_dict["dino_local_crops_loss"] = dino_local_crops_loss
+
+ # accumulate loss
+ loss_accumulator += self.dino_loss_weight * dino_local_crops_loss
+
+ # process global crops
+ loss_scales = 2 # this is here since we process global crops together
+
+ if do_dino:
+ # compute loss
+ dino_global_crops_loss = (
+ self.dino_loss(
+ student_output_list=[student_global_cls_tokens_after_head],
+ teacher_out_softmaxed_centered_list=[
+ teacher_dino_softmaxed_centered_list.flatten(0, 1)
+ ], # these were chunked and stacked in reverse so A is matched to B
+ )
+ * loss_scales
+ / (n_global_crops_loss_terms + n_local_crops_loss_terms)
+ )
+
+ loss_dict["dino_global_crops_loss"] = dino_global_crops_loss
+
+ # accumulate loss
+ loss_accumulator += self.dino_loss_weight * dino_global_crops_loss
+
+ student_cls_tokens = student_global_cls_tokens
+
+ if self.do_koleo:
+ koleo_loss = self.cfg.dino.koleo_loss_weight * sum(
+ self.koleo_loss(p) for p in student_cls_tokens.chunk(2)
+ ) # we don't apply koleo loss between cls tokens of a same image
+ loss_accumulator += koleo_loss
+ loss_dict["koleo_loss"] = (
+ koleo_loss / loss_scales
+ ) # this is to display the same losses as before but we can remove eventually
+
+ if do_ibot:
+ # compute loss
+ ibot_patch_loss = (
+ self.ibot_patch_loss.forward_masked(
+ student_global_masked_patch_tokens_after_head,
+ masked_teacher_ibot_softmaxed_centered,
+ student_masks_flat=masks,
+ n_masked_patches=n_masked_patches,
+ masks_weight=masks_weight,
+ )
+ * loss_scales
+ * ibot_loss_scale
+ )
+
+ # store for display
+ loss_dict["ibot_loss"] = ibot_patch_loss / 2
+
+ # accumulate loss
+ loss_accumulator += self.ibot_loss_weight * ibot_patch_loss
+
+ self.backprop_loss(loss_accumulator)
+
+ self.fsdp_synchronize_streams()
+
+ return loss_dict
+
+ def fsdp_synchronize_streams(self):
+ if self.need_to_synchronize_fsdp_streams:
+ torch.cuda.synchronize()
+ self.student.dino_head._streams = (
+ self.teacher.dino_head._streams
+ ) = self.student.backbone._streams = self.teacher.backbone._streams
+ self.need_to_synchronize_fsdp_streams = False
+
+ def update_teacher(self, m):
+ student_param_list = []
+ teacher_param_list = []
+ with torch.no_grad():
+ for k in self.student.keys():
+ for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])):
+ student_param_list += ms.params
+ teacher_param_list += mt.params
+ torch._foreach_mul_(teacher_param_list, m)
+ torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m)
+
+ def train(self):
+ super().train()
+ self.teacher.eval()
+
+ def get_maybe_fused_params_for_submodel(self, m):
+ params_groups = get_params_groups_with_decay(
+ model=m,
+ lr_decay_rate=self.cfg.optim.layerwise_decay,
+ patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult,
+ )
+ fused_params_groups = fuse_params_groups(params_groups)
+ logger.info("fusing param groups")
+
+ for g in fused_params_groups:
+ g["foreach"] = True
+ return fused_params_groups
+
+ def get_params_groups(self):
+ all_params_groups = []
+ for m in self.student.values():
+ all_params_groups += self.get_maybe_fused_params_for_submodel(m)
+ return all_params_groups
+
+ def prepare_for_distributed_training(self):
+ logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
+ if has_batchnorms(self.student):
+ raise NotImplementedError
+ # below will synchronize all student subnetworks across gpus:
+ for k, v in self.student.items():
+ self.teacher[k].load_state_dict(self.student[k].state_dict())
+ student_model_cfg = self.cfg.compute_precision.student[k]
+ self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
+ teacher_model_cfg = self.cfg.compute_precision.teacher[k]
+ self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
diff --git a/mapper/models/dinov2/train/train.py b/mapper/models/dinov2/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..473b8d01473654182de9f91c94a2d8720fe096a5
--- /dev/null
+++ b/mapper/models/dinov2/train/train.py
@@ -0,0 +1,318 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import math
+import os
+from functools import partial
+
+from fvcore.common.checkpoint import PeriodicCheckpointer
+import torch
+
+from dinov2.data import SamplerType, make_data_loader, make_dataset
+from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
+import dinov2.distributed as distributed
+from dinov2.fsdp import FSDPCheckpointer
+from dinov2.logging import MetricLogger
+from dinov2.utils.config import setup
+from dinov2.utils.utils import CosineScheduler
+
+from dinov2.train.ssl_meta_arch import SSLMetaArch
+
+
+torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default
+logger = logging.getLogger("dinov2")
+
+
+def get_args_parser(add_help: bool = True):
+ parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
+ parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+ parser.add_argument(
+ "--no-resume",
+ action="store_true",
+ help="Whether to not attempt to resume from the checkpoint directory. ",
+ )
+ parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+ parser.add_argument("--eval", type=str, default="", help="Eval type to perform")
+ parser.add_argument(
+ "opts",
+ help="""
+Modify config options at the end of the command. For Yacs configs, use
+space-separated "PATH.KEY VALUE" pairs.
+For python-based LazyConfig, use "path.key=value".
+ """.strip(),
+ default=None,
+ nargs=argparse.REMAINDER,
+ )
+ parser.add_argument(
+ "--output-dir",
+ "--output_dir",
+ default="",
+ type=str,
+ help="Output directory to save logs and checkpoints",
+ )
+
+ return parser
+
+
+def build_optimizer(cfg, params_groups):
+ return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2))
+
+
+def build_schedulers(cfg):
+ OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
+ lr = dict(
+ base_value=cfg.optim["lr"],
+ final_value=cfg.optim["min_lr"],
+ total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+ warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
+ start_warmup_value=0,
+ )
+ wd = dict(
+ base_value=cfg.optim["weight_decay"],
+ final_value=cfg.optim["weight_decay_end"],
+ total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+ )
+ momentum = dict(
+ base_value=cfg.teacher["momentum_teacher"],
+ final_value=cfg.teacher["final_momentum_teacher"],
+ total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+ )
+ teacher_temp = dict(
+ base_value=cfg.teacher["teacher_temp"],
+ final_value=cfg.teacher["teacher_temp"],
+ total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+ warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+ start_warmup_value=cfg.teacher["warmup_teacher_temp"],
+ )
+
+ lr_schedule = CosineScheduler(**lr)
+ wd_schedule = CosineScheduler(**wd)
+ momentum_schedule = CosineScheduler(**momentum)
+ teacher_temp_schedule = CosineScheduler(**teacher_temp)
+ last_layer_lr_schedule = CosineScheduler(**lr)
+
+ last_layer_lr_schedule.schedule[
+ : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
+ ] = 0 # mimicking the original schedules
+
+ logger.info("Schedulers ready.")
+
+ return (
+ lr_schedule,
+ wd_schedule,
+ momentum_schedule,
+ teacher_temp_schedule,
+ last_layer_lr_schedule,
+ )
+
+
+def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
+ for param_group in optimizer.param_groups:
+ is_last_layer = param_group["is_last_layer"]
+ lr_multiplier = param_group["lr_multiplier"]
+ wd_multiplier = param_group["wd_multiplier"]
+ param_group["weight_decay"] = wd * wd_multiplier
+ param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier
+
+
+def do_test(cfg, model, iteration):
+ new_state_dict = model.teacher.state_dict()
+
+ if distributed.is_main_process():
+ iterstring = str(iteration)
+ eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
+ os.makedirs(eval_dir, exist_ok=True)
+ # save teacher checkpoint
+ teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
+ torch.save({"teacher": new_state_dict}, teacher_ckp_path)
+
+
+def do_train(cfg, model, resume=False):
+ model.train()
+ inputs_dtype = torch.half
+ fp16_scaler = model.fp16_scaler # for mixed precision training
+
+ # setup optimizer
+
+ optimizer = build_optimizer(cfg, model.get_params_groups())
+ (
+ lr_schedule,
+ wd_schedule,
+ momentum_schedule,
+ teacher_temp_schedule,
+ last_layer_lr_schedule,
+ ) = build_schedulers(cfg)
+
+ # checkpointer
+ checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)
+
+ start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
+
+ OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
+ max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
+
+ periodic_checkpointer = PeriodicCheckpointer(
+ checkpointer,
+ period=3 * OFFICIAL_EPOCH_LENGTH,
+ max_iter=max_iter,
+ max_to_keep=3,
+ )
+
+ # setup data preprocessing
+
+ img_size = cfg.crops.global_crops_size
+ patch_size = cfg.student.patch_size
+ n_tokens = (img_size // patch_size) ** 2
+ mask_generator = MaskingGenerator(
+ input_size=(img_size // patch_size, img_size // patch_size),
+ max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
+ )
+
+ data_transform = DataAugmentationDINO(
+ cfg.crops.global_crops_scale,
+ cfg.crops.local_crops_scale,
+ cfg.crops.local_crops_number,
+ global_crops_size=cfg.crops.global_crops_size,
+ local_crops_size=cfg.crops.local_crops_size,
+ )
+
+ collate_fn = partial(
+ collate_data_and_cast,
+ mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
+ mask_probability=cfg.ibot.mask_sample_probability,
+ n_tokens=n_tokens,
+ mask_generator=mask_generator,
+ dtype=inputs_dtype,
+ )
+
+ # setup data loader
+
+ dataset = make_dataset(
+ dataset_str=cfg.train.dataset_path,
+ transform=data_transform,
+ target_transform=lambda _: (),
+ )
+ # sampler_type = SamplerType.INFINITE
+ sampler_type = SamplerType.SHARDED_INFINITE
+ data_loader = make_data_loader(
+ dataset=dataset,
+ batch_size=cfg.train.batch_size_per_gpu,
+ num_workers=cfg.train.num_workers,
+ shuffle=True,
+ seed=start_iter, # TODO: Fix this -- cfg.train.seed
+ sampler_type=sampler_type,
+ sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
+ drop_last=True,
+ collate_fn=collate_fn,
+ )
+
+ # training loop
+
+ iteration = start_iter
+
+ logger.info("Starting training from iteration {}".format(start_iter))
+ metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
+ metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
+ header = "Training"
+
+ for data in metric_logger.log_every(
+ data_loader,
+ 10,
+ header,
+ max_iter,
+ start_iter,
+ ):
+ current_batch_size = data["collated_global_crops"].shape[0] / 2
+ if iteration > max_iter:
+ return
+
+ # apply schedules
+
+ lr = lr_schedule[iteration]
+ wd = wd_schedule[iteration]
+ mom = momentum_schedule[iteration]
+ teacher_temp = teacher_temp_schedule[iteration]
+ last_layer_lr = last_layer_lr_schedule[iteration]
+ apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
+
+ # compute losses
+
+ optimizer.zero_grad(set_to_none=True)
+ loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
+
+ # clip gradients
+
+ if fp16_scaler is not None:
+ if cfg.optim.clip_grad:
+ fp16_scaler.unscale_(optimizer)
+ for v in model.student.values():
+ v.clip_grad_norm_(cfg.optim.clip_grad)
+ fp16_scaler.step(optimizer)
+ fp16_scaler.update()
+ else:
+ if cfg.optim.clip_grad:
+ for v in model.student.values():
+ v.clip_grad_norm_(cfg.optim.clip_grad)
+ optimizer.step()
+
+ # perform teacher EMA update
+
+ model.update_teacher(mom)
+
+ # logging
+
+ if distributed.get_global_size() > 1:
+ for v in loss_dict.values():
+ torch.distributed.all_reduce(v)
+ loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}
+
+ if math.isnan(sum(loss_dict_reduced.values())):
+ logger.info("NaN detected")
+ raise AssertionError
+ losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+ metric_logger.update(lr=lr)
+ metric_logger.update(wd=wd)
+ metric_logger.update(mom=mom)
+ metric_logger.update(last_layer_lr=last_layer_lr)
+ metric_logger.update(current_batch_size=current_batch_size)
+ metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
+
+ # checkpointing and testing
+
+ if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
+ do_test(cfg, model, f"training_{iteration}")
+ torch.cuda.synchronize()
+ periodic_checkpointer.step(iteration)
+
+ iteration = iteration + 1
+ metric_logger.synchronize_between_processes()
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def main(args):
+ cfg = setup(args)
+
+ model = SSLMetaArch(cfg).to(torch.device("cuda"))
+ model.prepare_for_distributed_training()
+
+ logger.info("Model:\n{}".format(model))
+ if args.eval_only:
+ iteration = (
+ FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
+ .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
+ .get("iteration", -1)
+ + 1
+ )
+ return do_test(cfg, model, f"manual_{iteration}")
+
+ do_train(cfg, model, resume=not args.no_resume)
+
+
+if __name__ == "__main__":
+ args = get_args_parser(add_help=True).parse_args()
+ main(args)
diff --git a/mapper/models/dinov2/utils/__init__.py b/mapper/models/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapper/models/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapper/models/dinov2/utils/cluster.py b/mapper/models/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/mapper/models/dinov2/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/mapper/models/dinov2/utils/config.py b/mapper/models/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/mapper/models/dinov2/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/mapper/models/dinov2/utils/dtype.py b/mapper/models/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/mapper/models/dinov2/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/mapper/models/dinov2/utils/param_groups.py b/mapper/models/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d1bcc83fdde39cd31a54664a1c47783be87605f
--- /dev/null
+++ b/mapper/models/dinov2/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
\ No newline at end of file
diff --git a/mapper/models/dinov2/utils/utils.py b/mapper/models/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f
--- /dev/null
+++ b/mapper/models/dinov2/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/mapper/models/feature_extractor_DPT.py b/mapper/models/feature_extractor_DPT.py
new file mode 100644
index 0000000000000000000000000000000000000000..d71bec2fd4b42767c497568726b023478932b816
--- /dev/null
+++ b/mapper/models/feature_extractor_DPT.py
@@ -0,0 +1,85 @@
+from .base import BaseModel
+from .schema import DINOConfiguration
+import logging
+import torch
+import torch.nn as nn
+
+import sys
+import re
+import os
+
+from .dinov2.eval.depth.ops.wrappers import resize
+from .dinov2.hub.backbones import dinov2_vitb14_reg
+
+module_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(module_dir)
+
+logger = logging.getLogger(__name__)
+
+
+class FeatureExtractor(BaseModel):
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ def build_encoder(self, conf: DINOConfiguration):
+ BACKBONE_SIZE = "small"
+ backbone_archs = {
+ "small": "vits14",
+ "base": "vitb14", # this one
+ "large": "vitl14",
+ "giant": "vitg14",
+ }
+ backbone_arch = backbone_archs[BACKBONE_SIZE]
+ self.crop_size = int(re.search(r"\d+", backbone_arch).group())
+ backbone_name = f"dinov2_{backbone_arch}"
+
+ self.backbone_model = dinov2_vitb14_reg(
+ pretrained=conf.pretrained, drop_path_rate=0.1)
+
+ if conf.frozen:
+ for param in self.backbone_model.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(0, 10):
+ for param in self.backbone_model.blocks[i].parameters():
+ param.requires_grad = False
+ self.backbone_model.blocks[i].drop_path1 = nn.Identity()
+ self.backbone_model.blocks[i].drop_path2 = nn.Identity()
+
+ self.feat_projection = torch.nn.Conv2d(
+ 768, conf.output_dim, kernel_size=1)
+
+ return self.backbone_model
+
+ def _init(self, conf: DINOConfiguration):
+ # Preprocessing
+ self.register_buffer("mean_", torch.tensor(
+ self.mean), persistent=False)
+ self.register_buffer("std_", torch.tensor(self.std), persistent=False)
+
+ self.build_encoder(conf)
+
+ def _forward(self, data):
+ _, _, h, w = data["image"].shape
+
+ h_num_patches = h // self.crop_size
+ w_num_patches = w // self.crop_size
+
+ h_dino = h_num_patches * self.crop_size
+ w_dino = w_num_patches * self.crop_size
+
+ image = resize(data["image"], (h_dino, w_dino))
+
+ image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
+
+ output = self.backbone_model.forward_features(
+ image)['x_norm_patchtokens']
+ output = output.reshape(-1, h_num_patches,
+ w_num_patches, output.shape[-1])
+ output = output.permute(0, 3, 1, 2) # channel first
+ output = self.feat_projection(output)
+
+ camera = data['camera'].to(data["image"].device, non_blocking=True)
+ camera = camera.scale(output.shape[-1] / data["image"].shape[-1])
+
+ return output, camera
diff --git a/mapper/models/feature_extractor_resnet.py b/mapper/models/feature_extractor_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b6b2b479d75fa452b062e720e4397d73e3a31a0
--- /dev/null
+++ b/mapper/models/feature_extractor_resnet.py
@@ -0,0 +1,210 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from torchvision.models.feature_extraction import create_feature_extractor
+
+from .base import BaseModel
+from .schema import ResNetConfiguration
+
+logger = logging.getLogger(__name__)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
+ ):
+ super().__init__()
+ layers = []
+ for i in range(num_convs):
+ conv = nn.Conv2d(
+ previous if i == 0 else out,
+ out,
+ kernel_size=ksize,
+ padding=ksize // 2,
+ bias=norm is None,
+ padding_mode=padding,
+ )
+ layers.append(conv)
+ if norm is not None:
+ layers.append(norm(out))
+ layers.append(nn.ReLU(inplace=True))
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, previous, skip):
+ _, _, hp, wp = previous.shape
+ _, _, hs, ws = skip.shape
+ scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
+ upsampled = nn.functional.interpolate(
+ previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
+ )
+ # If the shape of the input map `skip` is not a multiple of 2,
+ # it will not match the shape of the upsampled map `upsampled`.
+ # If the downsampling uses ceil_mode=False, we nedd to crop `skip`.
+ # If it uses ceil_mode=True (not supported here), we should pad it.
+ _, _, hu, wu = upsampled.shape
+ _, _, hs, ws = skip.shape
+ if (hu <= hs) and (wu <= ws):
+ skip = skip[:, :, :hu, :wu]
+ elif (hu >= hs) and (wu >= ws):
+ skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
+ else:
+ raise ValueError(
+ f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
+ )
+
+ return self.layers(skip) + upsampled
+
+
+class FPN(nn.Module):
+ def __init__(self, in_channels_list, out_channels, **kw):
+ super().__init__()
+ self.first = nn.Conv2d(
+ in_channels_list[-1], out_channels, 1, padding=0, bias=True
+ )
+ self.blocks = nn.ModuleList(
+ [
+ DecoderBlock(c, out_channels, ksize=1, **kw)
+ for c in in_channels_list[::-1][1:]
+ ]
+ )
+ self.out = nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, layers):
+ feats = None
+ for idx, x in enumerate(reversed(layers.values())):
+ if feats is None:
+ feats = self.first(x)
+ else:
+ feats = self.blocks[idx - 1](feats, x)
+ out = self.out(feats)
+ return out
+
+
+def remove_conv_stride(conv):
+ conv_new = nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ conv.kernel_size,
+ bias=conv.bias is not None,
+ stride=1,
+ padding=conv.padding,
+ )
+ conv_new.weight = conv.weight
+ conv_new.bias = conv.bias
+ return conv_new
+
+
+class FeatureExtractor(BaseModel):
+ default_conf = {
+ "pretrained": True,
+ "input_dim": 3,
+ "output_dim": 128, # # of channels in output feature maps
+ "encoder": "resnet50", # torchvision net as string
+ "remove_stride_from_first_conv": False,
+ "num_downsample": None, # how many downsample block
+ "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks
+ "do_average_pooling": False,
+ "checkpointed": False, # whether to use gradient checkpointing
+ }
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ def freeze_encoder(self):
+ """
+ Freeze the encoder part of the model, i.e., set requires_grad = False
+ for all parameters in the encoder.
+ """
+ for param in self.encoder.parameters():
+ param.requires_grad = False
+ logger.debug("Encoder has been frozen.")
+
+ def unfreeze_encoder(self):
+ """
+ Unfreeze the encoder part of the model, i.e., set requires_grad = True
+ for all parameters in the encoder.
+ """
+ for param in self.encoder.parameters():
+ param.requires_grad = True
+ logger.debug("Encoder has been unfrozen.")
+
+ def build_encoder(self, conf: ResNetConfiguration):
+ assert isinstance(conf.encoder, str)
+ if conf.pretrained:
+ assert conf.input_dim == 3
+ Encoder = getattr(torchvision.models, conf.encoder)
+
+ kw = {}
+ if conf.encoder.startswith("resnet"):
+ layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
+ kw["replace_stride_with_dilation"] = [False, False, False]
+ elif conf.encoder == "vgg13":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.13",
+ "features.18",
+ "features.23",
+ ]
+ elif conf.encoder == "vgg16":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.15",
+ "features.22",
+ "features.29",
+ ]
+ else:
+ raise NotImplementedError(conf.encoder)
+
+ if conf.num_downsample is not None:
+ layers = layers[: conf.num_downsample]
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
+ if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
+ encoder.conv1 = remove_conv_stride(encoder.conv1)
+
+ if conf.do_average_pooling:
+ raise NotImplementedError
+ if conf.checkpointed:
+ raise NotImplementedError
+
+ return encoder, layers
+
+ def _init(self, conf):
+ # Preprocessing
+ self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
+ self.register_buffer("std_", torch.tensor(self.std), persistent=False)
+
+ # Encoder
+ self.encoder, self.layers = self.build_encoder(conf)
+ s = 128
+ inp = torch.zeros(1, 3, s, s)
+ features = list(self.encoder(inp).values())
+ self.skip_dims = [x.shape[1] for x in features]
+ self.layer_strides = [s / f.shape[-1] for f in features]
+ self.scales = [self.layer_strides[0]]
+
+ # Decoder
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
+ self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)
+
+ logger.debug(
+ "Built feature extractor with layers {name:dim:stride}:\n"
+ f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
+ f"and output scales {self.scales}."
+ )
+
+ def _forward(self, data):
+ image = data["image"]
+ image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
+
+ skip_features = self.encoder(image)
+ output = self.decoder(skip_features)
+ return output, data['camera']
diff --git a/mapper/models/loss.py b/mapper/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..76c9c6f570dd093a9acb4c2ced794736797ea1d7
--- /dev/null
+++ b/mapper/models/loss.py
@@ -0,0 +1,144 @@
+from typing import Optional, Dict
+import torch.nn as nn
+import torch
+from .schema import LossConfiguration
+
+
+def dice_loss(input: torch.Tensor,
+ target: torch.Tensor,
+ loss_mask: torch.Tensor,
+ class_weights: Optional[torch.Tensor | bool],
+ smooth=1e-5):
+ '''
+ :param input: (B, H, W, C) Logits for each class
+ :param target: (B, H, W, C) Ground truth class labels in one_hot
+ :param loss_mask: (B, H, W) Mask indicating valid regions of the image
+ :param class_weights: (C) Weights for each class
+ :param smooth: Smoothing factor to avoid division by zero, default 1.0
+ '''
+
+ if isinstance(class_weights, torch.Tensor):
+ class_weights = class_weights.unsqueeze(0)
+ elif class_weights is None or class_weights == False:
+ class_weights = torch.ones(
+ 1, target.size(-1), dtype=target.dtype, device=target.device)
+ elif class_weights == True:
+ class_weights = target.sum(1)
+ class_weights = torch.reciprocal(target.mean(1) + 1e-3)
+ class_weights = class_weights.clamp(min=1e-5)
+ # Only consider classes that are present
+ class_weights *= (target.sum(1) != 0).float()
+ class_weights.requires_grad = False
+
+ intersect = (2 * input * target)
+ intersect = (intersect) + smooth
+
+ union = (input + target)
+ union = (union) + smooth
+
+ loss = 1 - (intersect / union) # B, H, W, C
+ loss *= class_weights.unsqueeze(0).unsqueeze(0)
+ loss = loss.sum(-1) / class_weights.sum()
+ loss *= loss_mask
+ loss = loss.sum() / loss_mask.sum() # 1
+
+ return loss
+
+
+class EnhancedLoss(nn.Module):
+ def __init__(
+ self,
+ cfg: LossConfiguration,
+ ): # following params in the paper
+ super(EnhancedLoss, self).__init__()
+ self.num_classes = cfg.num_classes
+ self.xent_weight = cfg.xent_weight
+ self.focal = cfg.focal_loss
+ self.focal_gamma = cfg.focal_loss_gamma
+ self.dice_weight = cfg.dice_weight
+ # self.class_mapping =
+
+ if self.xent_weight == 0. and self.dice_weight == 0.:
+ raise ValueError(
+ "At least one of xent_weight and dice_weight must be greater than 0.")
+
+ if self.xent_weight > 0.:
+ self.xent_loss = nn.BCEWithLogitsLoss(
+ reduction="none"
+ )
+
+ if self.dice_weight > 0.:
+ self.dice_loss = dice_loss
+
+ if cfg.class_weights is not None and cfg.class_weights != True:
+ self.register_buffer("class_weights", torch.tensor(
+ cfg.class_weights), persistent=False)
+ else:
+ self.class_weights = cfg.class_weights
+
+ self.class_weights: Optional[torch.Tensor | bool]
+
+ self.requires_frustrum = cfg.requires_frustrum
+ self.requires_flood_mask = cfg.requires_flood_mask
+ self.label_smoothing = cfg.label_smoothing
+
+ def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
+ '''
+ Args:
+ pred: Dict containing the
+ - output: (B, C, H, W) Probabilities for each class
+ - valid_bev: (B, H, W) Mask indicating valid regions of the image
+ - conf: (B, H, W) Confidence map
+ data: Dict containing the
+ - seg_masks: (B, H, W, C) Ground truth class labels, one-hot encoded
+ - confidence_map: (B, H, W) Confidence map
+ '''
+ loss = {}
+
+ probs = pred['output'].permute(0, 2, 3, 1) # (B, H, W, C)
+ logits = pred['logits'].permute(0, 2, 3, 1) # (B, H, W, C)
+ labels: torch.Tensor = data['seg_masks'] # (B, H, W, C)
+
+ loss_mask = torch.ones(
+ labels.shape[:3], device=labels.device, dtype=labels.dtype)
+
+ if self.requires_frustrum:
+ frustrum_mask = pred["valid_bev"][..., :-1] != 0
+ loss_mask = loss_mask * frustrum_mask.float()
+
+ if self.requires_flood_mask:
+ flood_mask = data["flood_masks"] == 0
+ loss_mask = loss_mask * flood_mask.float()
+
+ if self.xent_weight > 0.:
+
+ if self.label_smoothing > 0.:
+ labels_ls = labels.float().clone()
+ labels_ls = labels_ls * \
+ (1 - self.label_smoothing) + \
+ self.label_smoothing / self.num_classes
+
+ xent_loss = self.xent_loss(logits, labels_ls)
+ else:
+ xent_loss = self.xent_loss(logits, labels)
+
+ if self.focal:
+ pt = torch.exp(-xent_loss)
+ xent_loss = (1 - pt) ** self.focal_gamma * xent_loss
+
+ xent_loss *= loss_mask.unsqueeze(-1)
+ xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
+ loss['cross_entropy'] = xent_loss
+ loss['total'] = xent_loss * self.xent_weight
+
+ if self.dice_weight > 0.:
+ dloss = self.dice_loss(
+ probs, labels, loss_mask, self.class_weights)
+ loss['dice'] = dloss
+
+ if 'total' in loss:
+ loss['total'] += dloss * self.dice_weight
+ else:
+ loss['total'] = dloss * self.dice_weight
+
+ return loss
diff --git a/mapper/models/map_perception_net.py b/mapper/models/map_perception_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b365239601a72f6084f9440da8c844a0198381e
--- /dev/null
+++ b/mapper/models/map_perception_net.py
@@ -0,0 +1,91 @@
+import torch
+
+from .metrics import PixelAccuracy, MeanObservableIOU, MeanUnobservableIOU, ObservableIOU, UnobservableIOU, mAP
+
+from .loss import EnhancedLoss
+
+from .segmentation_head import SegmentationHead
+
+from . import get_model
+from .base import BaseModel
+from .bev_projection import CartesianProjection, PolarProjectionDepth
+from .schema import ModelConfiguration
+
+class MapPerceptionNet(BaseModel):
+
+ def _init(self, conf: ModelConfiguration):
+ self.image_encoder = get_model(
+ conf.image_encoder.name
+ )(conf.image_encoder.backbone)
+
+ self.decoder = SegmentationHead(
+ in_channels=conf.latent_dim, n_classes=conf.num_classes)
+
+ ppm = conf.pixel_per_meter
+ self.projection_polar = PolarProjectionDepth(
+ conf.z_max,
+ ppm,
+ conf.scale_range,
+ conf.z_min,
+ )
+ self.projection_bev = CartesianProjection(
+ conf.z_max, conf.x_max, ppm, conf.z_min
+ )
+
+ self.scale_classifier = torch.nn.Linear(
+ conf.latent_dim, conf.num_scale_bins
+ ) # l4 - working
+
+ self.num_classes = conf.num_classes
+
+ self.loss_fn = EnhancedLoss(conf.loss)
+
+ def _forward(self, data):
+ f_image, camera = self.image_encoder(data)
+
+ scales = self.scale_classifier(
+ f_image.moveaxis(1, -1))
+ f_polar = self.projection_polar(f_image, scales, camera)
+
+ # Map to the BEV.
+ f_bev, valid_bev, _ = self.projection_bev(
+ f_polar.float(), None, camera.float()
+ )
+
+ output = self.decoder(f_bev[..., :-1])
+
+ probs = torch.nn.functional.sigmoid(output)
+
+ return {
+ "output": probs,
+ "logits": output,
+ "scales": scales,
+ "features_image": f_image,
+ "features_bev": f_bev,
+ "valid_bev": valid_bev.squeeze(1),
+ }
+
+ def loss(self, pred, data):
+ loss = self.loss_fn(pred, data)
+ return loss
+
+ def metrics(self):
+ m = {
+ "pix_acc": PixelAccuracy(),
+ "map": mAP(self.num_classes),
+ "miou_observable": MeanObservableIOU(self.num_classes),
+ "miou_non_observable": MeanUnobservableIOU(self.num_classes),
+ }
+ m.update(
+ {
+ f"IoU_observable_class_{i}": ObservableIOU(i, num_classes=self.num_classes)
+ for i in range(self.num_classes)
+ }
+ )
+ m.update(
+ {
+ f"IoU_non_observable_{i}": UnobservableIOU(i, num_classes=self.num_classes)
+ for i in range(self.num_classes)
+ }
+ )
+ return m
diff --git a/mapper/models/metrics.py b/mapper/models/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f9f2238121d9f1f9ea2f5f2b4c252110a44679d
--- /dev/null
+++ b/mapper/models/metrics.py
@@ -0,0 +1,129 @@
+import torch
+import torchmetrics
+import torchmetrics.classification
+
+
+class PixelAccuracy(torchmetrics.Metric):
+ def __init__(self):
+ super().__init__()
+ self.add_state("correct_pixels", default=torch.tensor(
+ 0), dist_reduce_fx="sum")
+ self.add_state("total_pixels", default=torch.tensor(0),
+ dist_reduce_fx="sum")
+
+ def update(self, pred, data):
+ output_mask = pred['output'] > 0.5
+ gt_mask = data["seg_masks"].permute(0, 3, 1, 2)
+ self.correct_pixels += (
+ (output_mask == gt_mask).sum()
+ )
+ self.total_pixels += torch.numel(pred["valid_bev"][..., :-1])
+
+ def compute(self):
+ return self.correct_pixels / self.total_pixels
+
+
+class IOU(torchmetrics.Metric):
+ def __init__(self, num_classes=3, **kwargs):
+ super().__init__(**kwargs)
+ self.num_classes = num_classes
+ self.add_state("intersection_observable", default=torch.zeros(
+ num_classes), dist_reduce_fx="sum")
+ self.add_state("union_observable", default=torch.zeros(
+ num_classes), dist_reduce_fx="sum")
+ self.add_state("intersection_non_observable",
+ default=torch.zeros(num_classes), dist_reduce_fx="sum")
+ self.add_state("union_non_observable", default=torch.zeros(
+ num_classes), dist_reduce_fx="sum")
+
+ def update(self, output, data):
+
+ gt = data["seg_masks"]
+ pred = output['output']
+
+ if "confidence_map" in data:
+ observable_mask = torch.logical_and(
+ output["valid_bev"][..., :-1], data["confidence_map"] == 0)
+ non_observable_mask = torch.logical_and(
+ output["valid_bev"][..., :-1], data["confidence_map"] == 1)
+ else:
+ observable_mask = output["valid_bev"][..., :-1]
+ non_observable_mask = torch.logical_not(observable_mask)
+
+ for class_idx in range(self.num_classes):
+ pred_mask = pred[:, class_idx] > 0.5
+ gt_mask = gt[..., class_idx]
+
+ # For observable areas
+ intersection_observable = torch.logical_and(
+ torch.logical_and(pred_mask, gt_mask), observable_mask
+ ).sum()
+ union_observable = torch.logical_and(
+ torch.logical_or(pred_mask, gt_mask), observable_mask
+ ).sum()
+ self.intersection_observable[class_idx] += intersection_observable
+ self.union_observable[class_idx] += union_observable
+
+ # For non-observable areas
+ intersection_non_observable = torch.logical_and(
+ torch.logical_and(pred_mask, gt_mask), non_observable_mask
+ ).sum()
+ union_non_observable = torch.logical_and(
+ torch.logical_or(pred_mask, gt_mask), non_observable_mask
+ ).sum()
+
+ self.intersection_non_observable[class_idx] += intersection_non_observable
+ self.union_non_observable[class_idx] += union_non_observable
+
+ def compute(self):
+ raise NotImplemented
+
+
+class ObservableIOU(IOU):
+ def __init__(self, class_idx=0, **kwargs):
+ super().__init__(**kwargs)
+ self.class_idx = class_idx
+
+ def compute(self):
+ return (self.intersection_observable / (self.union_observable + 1e-6))[self.class_idx]
+
+
+class UnobservableIOU(IOU):
+ def __init__(self, class_idx=0, **kwargs):
+ super().__init__(**kwargs)
+ self.class_idx = class_idx
+
+ def compute(self):
+ return (self.intersection_non_observable / (self.union_non_observable + 1e-6))[self.class_idx]
+
+
+class MeanObservableIOU(IOU):
+ def compute(self):
+ return self.intersection_observable.sum() / (self.union_observable.sum() + 1e-6)
+
+
+class MeanUnobservableIOU(IOU):
+ def compute(self):
+ return self.intersection_non_observable.sum() / (self.union_non_observable.sum() + 1e-6)
+
+
+class mAP(torchmetrics.classification.MultilabelPrecision):
+ def __init__(self, num_labels, **kwargs):
+ super().__init__(num_labels=num_labels, **kwargs)
+
+ def update(self, output, data):
+
+ if "confidence_map" in data:
+ observable_mask = torch.logical_and(
+ output["valid_bev"][..., :-1], data["confidence_map"] == 0)
+ else:
+ observable_mask = output["valid_bev"][..., :-1]
+
+ pred = output['output']
+ pred = pred.permute(0, 2, 3, 1)
+ pred = pred[observable_mask]
+
+ target = data['seg_masks']
+ target = target[observable_mask]
+
+ super(mAP, self).update(pred, target)
diff --git a/mapper/models/schema.py b/mapper/models/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d78d8336cd506c3a475d5c88c12f3c7b0a001b9
--- /dev/null
+++ b/mapper/models/schema.py
@@ -0,0 +1,61 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+@dataclass
+class LossConfiguration:
+ num_classes: int
+
+ xent_weight: float = 1.0
+ dice_weight: float = 1.0
+ focal_loss: bool = False
+ focal_loss_gamma: float = 2.0
+ requires_frustrum: bool = True
+ requires_flood_mask: bool = False
+ class_weights: Optional[Any] = None
+ label_smoothing: float = 0.1
+
+@dataclass
+class BackboneConfigurationBase:
+ pretrained: bool
+ frozen: bool
+ output_dim: bool
+
+@dataclass
+class DINOConfiguration(BackboneConfigurationBase):
+ pretrained: bool = True
+ frozen: bool = False
+ output_dim: int = 128
+
+@dataclass
+class ResNetConfiguration(BackboneConfigurationBase):
+ input_dim: int
+ encoder: str
+ remove_stride_from_first_conv: bool
+ num_downsample: Optional[int]
+ decoder_norm: str
+ do_average_pooling: bool
+ checkpointed: bool
+
+@dataclass
+class ImageEncoderConfiguration:
+ name: str
+ backbone: Any
+
+@dataclass
+class ModelConfiguration:
+ segmentation_head: Dict[str, Any]
+ image_encoder: ImageEncoderConfiguration
+
+ name: str
+ num_classes: int
+ latent_dim: int
+ z_max: int
+ x_max: int
+
+ pixel_per_meter: int
+ num_scale_bins: int
+
+ loss: LossConfiguration
+
+ scale_range: list[int] = field(default_factory=lambda: [0, 9])
+ z_min: Optional[int] = None
\ No newline at end of file
diff --git a/mapper/models/segmentation_head.py b/mapper/models/segmentation_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b64c5258cade726de552fbc5a5a71f218435f5b
--- /dev/null
+++ b/mapper/models/segmentation_head.py
@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+class UpsamplingAdd(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
+ super().__init__()
+ self.upsample_layer = nn.Sequential(
+ nn.Upsample(
+ scale_factor=scale_factor, mode="bilinear", align_corners=False
+ ),
+ nn.Conv2d(in_channels, out_channels,
+ kernel_size=1, padding=0, bias=False),
+ nn.InstanceNorm2d(out_channels),
+ )
+
+ def forward(self, x: torch.Tensor, x_skip: torch.Tensor):
+ # Check if the width dimension is odd and needs zero padding
+ x = self.upsample_layer(x)
+
+ if x.shape[-1] != x_skip.shape[-1] or x.shape[-2] != x_skip.shape[-2]:
+ x = nn.functional.interpolate(
+ x, size=(x_skip.shape[-2], x_skip.shape[-1]), mode="bilinear"
+ )
+
+ return x + x_skip
+
+
+class SegmentationHead(nn.Module):
+ def __init__(self, in_channels: int, n_classes: int, dropout_rate: float = 0.0):
+ super(SegmentationHead, self).__init__()
+
+ backbone = models.resnet18(pretrained=False, zero_init_residual=True)
+
+ self.first_conv = nn.Conv2d(
+ in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
+ )
+ self.bn1 = backbone.bn1
+ self.relu = backbone.relu
+
+ self.layer1 = backbone.layer1
+ self.layer2 = backbone.layer2
+ self.layer3 = backbone.layer3
+
+ # Upsampling layers
+ self.up3_skip = UpsamplingAdd(
+ in_channels=256, out_channels=128, scale_factor=2)
+ self.up2_skip = UpsamplingAdd(
+ in_channels=128, out_channels=64, scale_factor=2)
+ self.up1_skip = UpsamplingAdd(
+ in_channels=64, out_channels=in_channels, scale_factor=2)
+
+ # Segmentation head
+ self.dropout = nn.Dropout(
+ dropout_rate) if dropout_rate > 0 else nn.Identity()
+
+ self.segmentation_head = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels,
+ kernel_size=3, padding=1, bias=False),
+ nn.InstanceNorm2d(in_channels),
+ nn.ReLU(inplace=True),
+ self.dropout,
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x: torch.Tensor):
+ # (H, W)
+ skip_x = {"1": x}
+ x = self.first_conv(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.dropout(x)
+
+ # (H/4, W/4)
+ x = self.layer1(x)
+ skip_x["2"] = x
+ x = self.dropout(x)
+
+ x = self.layer2(x)
+ skip_x["3"] = x
+ x = self.dropout(x)
+
+ # (H/8, W/8)
+ x = self.layer3(x)
+ x = self.dropout(x)
+
+ # First upsample to (H/4, W/4)
+ x = self.up3_skip(x, skip_x["3"])
+ x = self.dropout(x)
+
+ # Second upsample to (H/2, W/2)
+ x = self.up2_skip(x, skip_x["2"])
+ x = self.dropout(x)
+
+ # Third upsample to (H, W)
+ x = self.up1_skip(x, skip_x["1"])
+ x = self.dropout(x)
+
+ segmentation_output = self.segmentation_head(x)
+
+ return segmentation_output
diff --git a/mapper/models/utils.py b/mapper/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f18b3d2f1e17d92e5953393c674b5c2e1a0a21
--- /dev/null
+++ b/mapper/models/utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import math
+from typing import Optional
+
+import torch
+
+
+def checkpointed(cls, do=True):
+ """Adapted from the DISK implementation of Michał Tyszkiewicz."""
+ assert issubclass(cls, torch.nn.Module)
+
+ class Checkpointed(cls):
+ def forward(self, *args, **kwargs):
+ super_fwd = super(Checkpointed, self).forward
+ if any((torch.is_tensor(a) and a.requires_grad) for a in args):
+ return torch.utils.checkpoint.checkpoint(super_fwd, *args, **kwargs)
+ else:
+ return super_fwd(*args, **kwargs)
+
+ return Checkpointed if do else cls
+
+
+@torch.jit.script
+def make_grid(
+ w: float,
+ h: float,
+ step_x: float = 1.0,
+ step_y: float = 1.0,
+ orig_x: float = 0,
+ orig_y: float = 0,
+ y_up: bool = False,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ x, y = torch.meshgrid(
+ [
+ torch.arange(orig_x, w + orig_x, step_x, device=device),
+ torch.arange(orig_y, h + orig_y, step_y, device=device),
+ ],
+ indexing="xy",
+ )
+ if y_up:
+ y = y.flip(-2)
+ grid = torch.stack((x, y), -1)
+ return grid
+
+
+@torch.jit.script
+def rotmat2d(angle: torch.Tensor) -> torch.Tensor:
+ c = torch.cos(angle)
+ s = torch.sin(angle)
+ R = torch.stack([c, -s, s, c], -1).reshape(angle.shape + (2, 2))
+ return R
+
+
+@torch.jit.script
+def rotmat2d_grad(angle: torch.Tensor) -> torch.Tensor:
+ c = torch.cos(angle)
+ s = torch.sin(angle)
+ R = torch.stack([-s, -c, c, -s], -1).reshape(angle.shape + (2, 2))
+ return R
+
+
+def deg2rad(x):
+ return x * math.pi / 180
+
+
+def rad2deg(x):
+ return x * 180 / math.pi
diff --git a/mapper/module.py b/mapper/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0c96764b5aa73b46a4d7a4b4623ac3e5eddf55
--- /dev/null
+++ b/mapper/module.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from pathlib import Path
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import DictConfig, OmegaConf, open_dict
+from torchmetrics import MeanMetric, MetricCollection
+
+from . import logger
+from .models import get_model
+
+
+class AverageKeyMeter(MeanMetric):
+ def __init__(self, key, *args, **kwargs):
+ self.key = key
+ super().__init__(*args, **kwargs)
+
+ def update(self, dict):
+ value = dict[self.key]
+ value = value[torch.isfinite(value)]
+ return super().update(value)
+
+
+class GenericModule(pl.LightningModule):
+ def __init__(self, cfg):
+ super().__init__()
+ name = cfg.model.get("name")
+ name = "map_perception_net" if name is None else name
+ self.model = get_model(name)(cfg.model)
+ self.cfg = cfg
+ self.save_hyperparameters(cfg)
+ self.metrics_val = MetricCollection(
+ self.model.metrics(), prefix="val/")
+ self.losses_val = None # we do not know the loss keys in advance
+
+ def forward(self, batch):
+ return self.model(batch)
+
+ def training_step(self, batch):
+ pred = self(batch)
+ losses = self.model.loss(pred, batch)
+ self.log_dict(
+ {f"train/loss/{k}": v.mean() for k, v in losses.items()},
+ prog_bar=True,
+ rank_zero_only=True,
+ on_epoch=True,
+ sync_dist=True
+ )
+ return losses["total"].mean()
+
+ def validation_step(self, batch, batch_idx):
+ pred = self(batch)
+ losses = self.model.loss(pred, batch)
+ if self.losses_val is None:
+ self.losses_val = MetricCollection(
+ {k: AverageKeyMeter(k).to(self.device) for k in losses},
+ prefix="val/",
+ postfix="/loss",
+ )
+ self.metrics_val(pred, batch)
+ self.log_dict(self.metrics_val, on_epoch=True)
+ self.losses_val.update(losses)
+ self.log_dict(self.losses_val, on_epoch=True)
+
+ return pred
+
+ def test_step(self, batch, batch_idx):
+ pred = self(batch)
+
+ return pred
+
+ def validation_epoch_start(self, batch):
+ self.losses_val = None
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(
+ self.parameters(), lr=self.cfg.training.lr)
+ ret = {"optimizer": optimizer}
+ cfg_scheduler = self.cfg.training.get("lr_scheduler")
+ if cfg_scheduler is not None:
+ scheduler_args = cfg_scheduler.get("args", {})
+ for key in scheduler_args:
+ if scheduler_args[key] == "$total_epochs":
+ scheduler_args[key] = int(self.trainer.max_epochs)
+ scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
+ optimizer=optimizer, **scheduler_args
+ )
+ ret["lr_scheduler"] = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ "frequency": 1,
+ "monitor": "loss/total/val",
+ "strict": True,
+ "name": "learning_rate",
+ }
+ return ret
+
+ @classmethod
+ def load_from_checkpoint(
+ cls,
+ checkpoint_path,
+ map_location=None,
+ hparams_file=None,
+ strict=True,
+ cfg=None,
+ find_best=False,
+ ):
+ assert hparams_file is None, "hparams are not supported."
+
+ checkpoint = torch.load(
+ checkpoint_path, map_location=map_location or (
+ lambda storage, loc: storage)
+ )
+ if find_best:
+ best_score, best_name = None, None
+ modes = {"min": torch.lt, "max": torch.gt}
+ for key, state in checkpoint["callbacks"].items():
+ if not key.startswith("ModelCheckpoint"):
+ continue
+ mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
+ if best_score is None or modes[mode](
+ state["best_model_score"], best_score
+ ):
+ best_score = state["best_model_score"]
+ best_name = Path(state["best_model_path"]).name
+ logger.info("Loading best checkpoint %s", best_name)
+ if best_name != checkpoint_path:
+ return cls.load_from_checkpoint(
+ Path(checkpoint_path).parent / best_name,
+ map_location,
+ hparams_file,
+ strict,
+ cfg,
+ find_best=False,
+ )
+
+ logger.info(
+ "Using checkpoint %s from epoch %d and step %d.",
+ checkpoint_path,
+ checkpoint["epoch"],
+ checkpoint["global_step"],
+ )
+ cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
+ if list(cfg_ckpt.keys()) == ["cfg"]: # backward compatibility
+ cfg_ckpt = cfg_ckpt["cfg"]
+ cfg_ckpt = OmegaConf.create(cfg_ckpt)
+
+ if cfg is None:
+ cfg = {}
+ if not isinstance(cfg, DictConfig):
+ cfg = OmegaConf.create(cfg)
+ with open_dict(cfg_ckpt):
+ cfg = OmegaConf.merge(cfg_ckpt, cfg)
+
+ return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
diff --git a/mapper/utils/generate_raycast_masks.py b/mapper/utils/generate_raycast_masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..37fe00a30a715a60ec75d75adc6ab175e23d098d
--- /dev/null
+++ b/mapper/utils/generate_raycast_masks.py
@@ -0,0 +1,251 @@
+from multiprocessing import Pool
+from tqdm import tqdm
+from pathlib import Path
+import numpy as np
+from collections import deque
+import argparse
+import cv2
+
+def get_raycast_building_mask(building_grid):
+ laser_range = 200
+ num_laser = 100
+ robot_pos = (building_grid.shape[0] // 2-1, building_grid.shape[1] // 2 - 1)
+ unoccupied_pos = np.stack(np.where(building_grid != 1), axis=1)
+
+ if len(unoccupied_pos) == 0:
+ return None
+
+ l2_dist = unoccupied_pos - [robot_pos[0], robot_pos[1]]
+ closest = ((l2_dist ** 2).sum(1)**0.5).argmin()
+
+ robot_pos = (unoccupied_pos[closest][0], unoccupied_pos[closest][1])
+
+ free_points, hit_points, actual_hit_points = get_free_points_in_front(building_grid, robot_pos, laser_range=laser_range, num_laser=num_laser)
+ free_points[:, 0][free_points[:, 0] >= building_grid.shape[0]] = building_grid.shape[0] - 1
+ free_points[:, 1][free_points[:, 1] >= building_grid.shape[1]] = building_grid.shape[1] - 1
+ free_points[:, 0][free_points[:, 0] < 0] = 0
+ free_points[:, 1][free_points[:, 1] < 0] = 0
+
+ hit_points[:, 0][hit_points[:, 0] >= building_grid.shape[0]] = building_grid.shape[0] - 1
+ hit_points[:, 1][hit_points[:, 1] >= building_grid.shape[1]] = building_grid.shape[1] - 1
+ hit_points[:, 0][hit_points[:, 0] < 0] = 0
+ hit_points[:, 1][hit_points[:, 1] < 0] = 0
+
+ if len(free_points) > 0:
+
+ # Get vis mask by flood filling free space boundary
+ inited_flood_grid = init_flood_fill(robot_pos, hit_points, building_grid.shape)
+ inited_flood_grid = (inited_flood_grid * 255).astype(np.uint8).copy()
+
+ # pick a seed point from free points, that is not 0 in inited_flood_grid. We want it to be unknown
+ np.random.shuffle(free_points)
+
+ for i in range(len(free_points)):
+ seed_point = free_points[i]
+ if inited_flood_grid[seed_point[0], seed_point[1]] != 0:
+ break # Found a valid seed point, exit the loop
+ else:
+ print('Unable to find a valid seed point')
+ return None
+
+ num_filled, flooded_image, mask, bounding_box = cv2.floodFill(inited_flood_grid.copy(), None, seedPoint=(seed_point[1], seed_point[0]), newVal=0)
+ # name = names[batch_ind][-1]
+ return flooded_image
+ else:
+ print("No free points")
+ return None
+
+def flood_fill_simple(center_point, occupancy_map):
+ """
+ center_point: starting point (x,y) of fill
+ occupancy_map: occupancy map generated from Bresenham ray-tracing
+ """
+ # Fill empty areas with queue method
+ occupancy_map = np.copy(occupancy_map)
+ sx, sy = occupancy_map.shape
+ fringe = deque()
+ fringe.appendleft(center_point)
+ while fringe:
+
+ n = fringe.pop()
+ nx, ny = n
+ unknown_val = 0.5
+ # West
+ if nx > 0:
+ if occupancy_map[nx - 1, ny] == unknown_val:
+ occupancy_map[nx - 1, ny] = 0
+ fringe.appendleft((nx - 1, ny))
+ # East
+ if nx < sx - 1:
+ if occupancy_map[nx + 1, ny] == unknown_val:
+ occupancy_map[nx + 1, ny] = 0
+ fringe.appendleft((nx + 1, ny))
+ # North
+ if ny > 0:
+ if occupancy_map[nx, ny - 1] == unknown_val:
+ occupancy_map[nx, ny - 1] = 0
+ fringe.appendleft((nx, ny - 1))
+ # South
+ if ny < sy - 1:
+ if occupancy_map[nx, ny + 1] == unknown_val:
+ occupancy_map[nx, ny + 1] = 0
+ fringe.appendleft((nx, ny + 1))
+ return occupancy_map
+
+def init_flood_fill(robot_pos, obstacle_points, occ_grid_shape):
+ """
+ center_point: center point
+ obstacle_points: detected obstacles points (x,y)
+ xy_points: (x,y) point pairs
+ """
+ center_x, center_y = robot_pos
+ prev_ix, prev_iy = center_x, center_y
+ occupancy_map = (np.ones(occ_grid_shape)) * 0.5
+ # append first obstacle point to last
+ obstacle_points = np.vstack((obstacle_points, obstacle_points[0]))
+ for (x, y) in zip(obstacle_points[:,0], obstacle_points[:,1]):
+ # x coordinate of the the occupied area
+ ix = int(x)
+ # y coordinate of the the occupied area
+ iy = int(y)
+ free_area = bresenham((prev_ix, prev_iy), (ix, iy))
+ for fa in free_area:
+ occupancy_map[fa[0]][fa[1]] = 0 # free area 0.0
+ prev_ix = ix
+ prev_iy = iy
+ return occupancy_map
+
+show_animation = False
+
+def bresenham(start, end):
+ """
+ Implementation of Bresenham's line drawing algorithm
+ See en.wikipedia.org/wiki/Bresenham's_line_algorithm
+ Bresenham's Line Algorithm
+ Produces a np.array from start and end (original from roguebasin.com)
+ >>> points1 = bresenham((4, 4), (6, 10))
+ >>> print(points1)
+ np.array([[4,4], [4,5], [5,6], [5,7], [5,8], [6,9], [6,10]])
+ """
+ # setup initial conditions
+ x1, y1 = start
+ x2, y2 = end
+ dx = x2 - x1
+ dy = y2 - y1
+ is_steep = abs(dy) > abs(dx) # determine how steep the line is
+ if is_steep: # rotate line
+ x1, y1 = y1, x1
+ x2, y2 = y2, x2
+ # swap start and end points if necessary and store swap state
+ swapped = False
+ if x1 > x2:
+ x1, x2 = x2, x1
+ y1, y2 = y2, y1
+ swapped = True
+ dx = x2 - x1 # recalculate differentials
+ dy = y2 - y1 # recalculate differentials
+ error = int(dx / 2.0) # calculate error
+ y_step = 1 if y1 < y2 else -1
+ # iterate over bounding box generating points between start and end
+ y = y1
+ points = []
+ for x in range(x1, x2 + 1):
+ coord = [y, x] if is_steep else (x, y)
+ points.append(coord)
+ error -= abs(dy)
+ if error < 0:
+ y += y_step
+ error += dx
+ if swapped: # reverse the list if the coordinates were swapped
+ points.reverse()
+ points = np.array(points)
+ return points
+
+def get_free_points_in_front(occupancy_grid, robot_pos, laser_range=10, num_laser=100):
+ """
+ Assumes circular lidar
+ occupancy_grid: np.array (h x w)
+ robot_pos: (x, y)
+
+ Outputs:
+ free_points: np.array of hit points (x, y)
+ """
+
+ free_points = []
+ hit_points = [] # actual hit points + last bresenham point (for some reason need this for flodding)
+ actual_hit_points = [] #
+ for orientation in np.linspace(np.pi/2, 3*np.pi/2, num_laser):
+ end_point = (round(robot_pos[0] + laser_range * np.cos(orientation)), round(robot_pos[1] + laser_range * np.sin(orientation)))
+
+ # Get index along ray to check
+ bresenham_points = (bresenham(robot_pos, end_point))
+
+ # Go through the points and see the first hit
+ # TODO: do a check if any first?
+ for i in range(len(bresenham_points)):
+ # if bresenham point is in the map
+ if bresenham_points[i,0] < 0 or bresenham_points[i,0] >= occupancy_grid.shape[0] or bresenham_points[i,1] < 0 or bresenham_points[i,1] >= occupancy_grid.shape[1]:
+ if i != 0:
+ hit_points.append(bresenham_points[i-1])
+ break # don't use this bresenham point
+
+ if occupancy_grid[bresenham_points[i,0], bresenham_points[i,1]] == 1: # hit if it is void or occupied #! THINK IF THIS IS A GOOD ASSUMPTION
+
+ for j in range(min(4, len(bresenham_points) - i - 1)): # add 4 points in front of hit
+ free_points.append(bresenham_points[i+j])
+
+ actual_hit_points.append(bresenham_points[i + j + 1])
+ hit_points.append(bresenham_points[i + j + 1])
+
+ break
+ else: # no hits
+ free_point = bresenham_points[i]
+ free_points.append(free_point)
+
+ if i == len(bresenham_points) - 1:
+ hit_points.append(end_point) # need to add this for proper flooding for vis mask
+ break
+
+
+ # Convert to np.array
+ free_points = np.array(free_points)
+ hit_points = np.array(hit_points)
+ actual_hit_points = np.array(actual_hit_points)
+ return free_points, hit_points, actual_hit_points
+
+if __name__ == "__main__":
+ # Argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_folder", type=str, default="/path/to/raycast")
+ parser.add_argument("--class_idx_building", type=int, default=4)
+ parser.add_argument("--num_workers", type=int, default=60)
+ parser.add_argument("--location", type=str, default="los_angeles")
+
+ args = parser.parse_args()
+
+ dataset_folder = Path(args.dataset_folder)
+ bev_folder = dataset_folder / args.location / "semantic_masks"
+ output_folder = dataset_folder / args.location / "flood_fill"
+
+ output_folder.mkdir(exist_ok=True, parents=True)
+
+ def generate_mask(filepath):
+ mask = np.load(filepath)
+ building_grid = mask[..., args.class_idx_building]
+ try:
+ flooded_image = get_raycast_building_mask(building_grid)
+ except:
+ raise Exception(f"Error in {filepath}")
+
+ if flooded_image is not None:
+ output_file = output_folder / filepath.name
+ np.save(output_file, flooded_image)
+ else:
+ print("No flood fill generated")
+
+ bev_files = list(bev_folder.iterdir())
+
+ with Pool(args.num_workers) as p:
+ for _ in tqdm(p.imap_unordered(generate_mask, bev_files), total=len(bev_files)):
+ pass
+
diff --git a/mapper/utils/geometry.py b/mapper/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bcbcba7c41e689e9dd9e35fe33e7787fdd13b03
--- /dev/null
+++ b/mapper/utils/geometry.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+import torch
+
+
+def from_homogeneous(points, eps: float = 1e-8):
+ """Remove the homogeneous dimension of N-dimensional points.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
+ Returns:
+ A torch.Tensor or numpy ndarray with size (..., N).
+ """
+ return points[..., :-1] / (points[..., -1:] + eps)
+
+
+def to_homogeneous(points):
+ """Convert N-dimensional points to homogeneous coordinates.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N).
+ Returns:
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
+ """
+ if isinstance(points, torch.Tensor):
+ pad = points.new_ones(points.shape[:-1] + (1,))
+ return torch.cat([points, pad], dim=-1)
+ elif isinstance(points, np.ndarray):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+ else:
+ raise ValueError
+
+
+@torch.jit.script
+def undistort_points(pts, dist):
+ dist = dist.unsqueeze(-2) # add point dimension
+ ndist = dist.shape[-1]
+ undist = pts
+ valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool)
+ if ndist > 0:
+ k1, k2 = dist[..., :2].split(1, -1)
+ r2 = torch.sum(pts**2, -1, keepdim=True)
+ radial = k1 * r2 + k2 * r2**2
+ undist = undist + pts * radial
+
+ # The distortion model is supposedly only valid within the image
+ # boundaries. Because of the negative radial distortion, points that
+ # are far outside of the boundaries might actually be mapped back
+ # within the image. To account for this, we discard points that are
+ # beyond the inflection point of the distortion model,
+ # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0
+ limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0))
+ limit = torch.abs(
+ torch.where(
+ k2 > 0,
+ (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2),
+ 1 / (3 * k1),
+ )
+ )
+ valid = valid & torch.squeeze(~limited | (r2 < limit), -1)
+
+ if ndist > 2:
+ p12 = dist[..., 2:]
+ p21 = p12.flip(-1)
+ uv = torch.prod(pts, -1, keepdim=True)
+ undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2)
+
+ return undist, valid
diff --git a/mapper/utils/io.py b/mapper/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae91c57917b8c794ee9340196b8603fa8a4c289
--- /dev/null
+++ b/mapper/utils/io.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import requests
+import shutil
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+
+from .. import logger
+
+DATA_URL = "https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023"
+
+
+def read_image(path, grayscale=False):
+ if grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise ValueError(f"Cannot read image {path}.")
+ if not grayscale and len(image.shape) == 3:
+ image = np.ascontiguousarray(image[:, :, ::-1]) # BGR to RGB
+ return image
+
+
+def write_torch_image(path, image):
+ image_cv2 = np.round(image.clip(0, 1) * 255).astype(int)[..., ::-1]
+ cv2.imwrite(str(path), image_cv2)
+
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, (np.ndarray, torch.Tensor)):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ return json.JSONEncoder.default(self, obj)
+
+
+def write_json(path, data):
+ with open(path, "w") as f:
+ json.dump(data, f, cls=JSONEncoder)
+
+
+def download_file(url, path):
+ path = Path(path)
+ if path.is_dir():
+ path = path / Path(url).name
+ path.parent.mkdir(exist_ok=True, parents=True)
+ logger.info("Downloading %s to %s.", url, path)
+ with requests.get(url, stream=True) as r:
+ total_length = int(r.headers.get("Content-Length"))
+ with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
+ with open(path, "wb") as output:
+ shutil.copyfileobj(raw, output)
+ return path
diff --git a/mapper/utils/viz_2d.py b/mapper/utils/viz_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72529625d63ee713bfbe76f0426a82288ba773d
--- /dev/null
+++ b/mapper/utils/viz_2d.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from Hierarchical-Localization, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/utils/viz.py
+# Released under the Apache License 2.0
+
+import numpy as np
+import torch
+
+
+def features_to_RGB(*Fs, masks=None, skip=1):
+ """Project a list of d-dimensional feature maps to RGB colors using PCA."""
+ from sklearn.decomposition import PCA
+
+ def normalize(x):
+ return x / np.linalg.norm(x, axis=-1, keepdims=True)
+
+ if masks is not None:
+ assert len(Fs) == len(masks)
+
+ flatten = []
+ for i, F in enumerate(Fs):
+ c, h, w = F.shape
+ F = np.rollaxis(F, 0, 3)
+ F_flat = F.reshape(-1, c)
+ if masks is not None and masks[i] is not None:
+ mask = masks[i]
+ assert mask.shape == F.shape[:2]
+ F_flat = F_flat[mask.reshape(-1)]
+ flatten.append(F_flat)
+ flatten = np.concatenate(flatten, axis=0)
+ flatten = normalize(flatten)
+
+ pca = PCA(n_components=3)
+ if skip > 1:
+ pca.fit(flatten[::skip])
+ flatten = pca.transform(flatten)
+ else:
+ flatten = pca.fit_transform(flatten)
+ flatten = (normalize(flatten) + 1) / 2
+
+ Fs_rgb = []
+ for i, F in enumerate(Fs):
+ h, w = F.shape[-2:]
+ if masks is None or masks[i] is None:
+ F_rgb, flatten = np.split(flatten, [h * w], axis=0)
+ F_rgb = F_rgb.reshape((h, w, 3))
+ else:
+ F_rgb = np.zeros((h, w, 3))
+ indices = np.where(masks[i])
+ F_rgb[indices], flatten = np.split(flatten, [len(indices[0])], axis=0)
+ F_rgb = np.concatenate([F_rgb, masks[i][..., None]], axis=-1)
+ Fs_rgb.append(F_rgb)
+ assert flatten.shape[0] == 0, flatten.shape
+ return Fs_rgb
+
+
+def one_hot_argmax_to_rgb(y, num_class):
+ '''
+ Args:
+ probs: (B, C, H, W)
+ num_class: int
+ 0: road 0
+1: crossing 1
+2: explicit_pedestrian 2
+4: building
+6: terrain
+7: parking `
+
+ '''
+
+
+ class_colors = {
+ 'road': (0, 0, 0), # 0: Black
+ 'crossing': (255, 0, 0), # 1; Red
+ 'explicit_pedestrian': (255, 255, 0), # 2: Yellow
+ # 'explicit_void': (128, 128, 128), # 3: White
+ 'park': (0, 255, 0), # 4: Green
+ 'building': (255, 0, 255), # 5: Magenta
+ 'water': (0, 0, 255), # 6: Blue
+ 'terrain': (0, 255, 255), # 7: Cyan
+ 'parking': (170, 170, 170), # 8: Dark Grey
+ 'train': (85, 85, 85) , # 9: Light Grey
+ 'predicted_void': (256, 256, 256)
+ }
+ class_colors = class_colors.values()
+ class_colors = [torch.tensor(x) for x in class_colors]
+
+ argmaxed = torch.argmax((y > 0.5).float(), dim=1) # Take argmax
+ argmaxed[torch.all(y <= 0.5, dim=1)] = num_class
+ # print(argmaxed.shape)
+
+ seg_rgb = torch.ones(
+ (
+ argmaxed.shape[0],
+ 3,
+ argmaxed.shape[1],
+ argmaxed.shape[2],
+ )
+ ) * 256
+ for i in range(num_class + 1):
+ seg_rgb[:, 0, :, :][argmaxed == i] = class_colors[i][0]
+ seg_rgb[:, 1, :, :][argmaxed == i] = class_colors[i][1]
+ seg_rgb[:, 2, :, :][argmaxed == i] = class_colors[i][2]
+
+ return seg_rgb
diff --git a/mapper/utils/wrappers.py b/mapper/utils/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd8480fd1257f216f929399c036bc057c9a2b51
--- /dev/null
+++ b/mapper/utils/wrappers.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""
+Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
+Based on PyTorch tensors: differentiable, batched, with GPU support.
+"""
+
+import functools
+import inspect
+import math
+from typing import Dict, List, NamedTuple, Tuple, Union
+
+import numpy as np
+import torch
+
+from .geometry import undistort_points
+
+
+def autocast(func):
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors
+ if they are numpy arrays. Use the device and dtype of the wrapper.
+ """
+
+ @functools.wraps(func)
+ def wrap(self, *args):
+ device = torch.device("cpu")
+ dtype = None
+ if isinstance(self, TensorWrapper):
+ if self._data is not None:
+ device = self.device
+ dtype = self.dtype
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
+ raise ValueError(self)
+
+ cast_args = []
+ for arg in args:
+ if isinstance(arg, np.ndarray):
+ arg = torch.from_numpy(arg)
+ arg = arg.to(device=device, dtype=dtype)
+ cast_args.append(arg)
+ return func(self, *cast_args)
+
+ return wrap
+
+
+class TensorWrapper:
+ _data = None
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ self._data = data
+
+ @property
+ def shape(self):
+ return self._data.shape[:-1]
+
+ @property
+ def device(self):
+ return self._data.device
+
+ @property
+ def dtype(self):
+ return self._data.dtype
+
+ def __getitem__(self, index):
+ return self.__class__(self._data[index])
+
+ def __setitem__(self, index, item):
+ self._data[index] = item.data
+
+ def to(self, *args, **kwargs):
+ return self.__class__(self._data.to(*args, **kwargs))
+
+ def cpu(self):
+ return self.__class__(self._data.cpu())
+
+ def cuda(self):
+ return self.__class__(self._data.cuda())
+
+ def pin_memory(self):
+ return self.__class__(self._data.pin_memory())
+
+ def float(self):
+ return self.__class__(self._data.float())
+
+ def double(self):
+ return self.__class__(self._data.double())
+
+ def detach(self):
+ return self.__class__(self._data.detach())
+
+ @classmethod
+ def stack(cls, objects: List, dim=0, *, out=None):
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
+ return cls(data)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func is torch.stack:
+ return cls.stack(*args, **kwargs)
+ else:
+ return NotImplemented
+
+
+class Pose(TensorWrapper):
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] == 12
+ super().__init__(data)
+
+ @classmethod
+ @autocast
+ def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
+ """Pose from a rotation matrix and translation vector.
+ Accepts numpy arrays or PyTorch tensors.
+
+ Args:
+ R: rotation matrix with shape (..., 3, 3).
+ t: translation vector with shape (..., 3).
+ """
+ assert R.shape[-2:] == (3, 3)
+ assert t.shape[-1] == 3
+ assert R.shape[:-2] == t.shape[:-1]
+ data = torch.cat([R.flatten(start_dim=-2), t], -1)
+ return cls(data)
+
+ @classmethod
+ def from_4x4mat(cls, T: torch.Tensor):
+ """Pose from an SE(3) transformation matrix.
+ Args:
+ T: transformation matrix with shape (..., 4, 4).
+ """
+ assert T.shape[-2:] == (4, 4)
+ R, t = T[..., :3, :3], T[..., :3, 3]
+ return cls.from_Rt(R, t)
+
+ @classmethod
+ def from_colmap(cls, image: NamedTuple):
+ """Pose from a COLMAP Image."""
+ return cls.from_Rt(image.qvec2rotmat(), image.tvec)
+
+ @property
+ def R(self) -> torch.Tensor:
+ """Underlying rotation matrix with shape (..., 3, 3)."""
+ rvec = self._data[..., :9]
+ return rvec.reshape(rvec.shape[:-1] + (3, 3))
+
+ @property
+ def t(self) -> torch.Tensor:
+ """Underlying translation vector with shape (..., 3)."""
+ return self._data[..., -3:]
+
+ def inv(self) -> "Pose":
+ """Invert an SE(3) pose."""
+ R = self.R.transpose(-1, -2)
+ t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ def compose(self, other: "Pose") -> "Pose":
+ """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
+ R = self.R @ other.R
+ t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ @autocast
+ def transform(self, p3d: torch.Tensor) -> torch.Tensor:
+ """Transform a set of 3D points.
+ Args:
+ p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
+ """
+ assert p3d.shape[-1] == 3
+ # assert p3d.shape[:-2] == self.shape # allow broadcasting
+ return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)
+
+ def __matmul__(
+ self, other: Union["Pose", torch.Tensor]
+ ) -> Union["Pose", torch.Tensor]:
+ """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.
+ or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
+ if isinstance(other, self.__class__):
+ return self.compose(other)
+ else:
+ return self.transform(other)
+
+ def numpy(self) -> Tuple[np.ndarray]:
+ return self.R.numpy(), self.t.numpy()
+
+ def magnitude(self) -> Tuple[torch.Tensor]:
+ """Magnitude of the SE(3) transformation.
+ Returns:
+ dr: rotation anngle in degrees.
+ dt: translation distance in meters.
+ """
+ trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
+ cos = torch.clamp((trace - 1) / 2, -1, 1)
+ dr = torch.acos(cos).abs() / math.pi * 180
+ dt = torch.norm(self.t, dim=-1)
+ return dr, dt
+
+ def __repr__(self):
+ return f"Pose: {self.shape} {self.dtype} {self.device}"
+
+
+class Camera(TensorWrapper):
+ eps = 1e-4
+
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] in {6, 8, 10}
+ super().__init__(data)
+
+ @classmethod
+ def from_dict(cls, camera: Union[Dict, NamedTuple]):
+ """Camera from a COLMAP Camera tuple or dictionary.
+ We assume that the origin (0, 0) is the center of the top-left pixel.
+ This is different from COLMAP.
+ """
+ if isinstance(camera, tuple):
+ camera = camera._asdict()
+
+ model = camera["model"]
+ params = camera["params"]
+
+ if model in ["OPENCV", "PINHOLE"]:
+ (fx, fy, cx, cy), params = np.split(params, [4])
+ elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"]:
+ (f, cx, cy), params = np.split(params, [3])
+ fx = fy = f
+ if model == "SIMPLE_RADIAL":
+ params = np.r_[params, 0.0]
+ else:
+ raise NotImplementedError(model)
+
+ data = np.r_[
+ camera["width"], camera["height"], fx, fy, cx - 0.5, cy - 0.5, params
+ ]
+ return cls(data)
+
+ @property
+ def size(self) -> torch.Tensor:
+ """Size (width height) of the images, with shape (..., 2)."""
+ return self._data[..., :2]
+
+ @property
+ def f(self) -> torch.Tensor:
+ """Focal lengths (fx, fy) with shape (..., 2)."""
+ return self._data[..., 2:4]
+
+ @property
+ def c(self) -> torch.Tensor:
+ """Principal points (cx, cy) with shape (..., 2)."""
+ return self._data[..., 4:6]
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., {0, 2, 4})."""
+ return self._data[..., 6:]
+
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
+ """Update the camera parameters after resizing an image."""
+ if isinstance(scales, (int, float)):
+ scales = (scales, scales)
+ s = self._data.new_tensor(scales)
+ data = torch.cat(
+ [self.size * s, self.f * s, (self.c + 0.5) * s - 0.5, self.dist], -1
+ )
+ return self.__class__(data)
+
+ def crop(self, left_top: Tuple[float], size: Tuple[int]):
+ """Update the camera parameters after cropping an image."""
+ left_top = self._data.new_tensor(left_top)
+ size = self._data.new_tensor(size)
+ data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
+ return self.__class__(data)
+
+ def flip(self):
+ """Update the camera parameters after flipping an image."""
+ data = self._data.clone()
+ data[..., 4] = self.size[..., 0] - self.c[..., 0] - 1
+ return self.__class__(data)
+
+ @autocast
+ def in_image(self, p2d: torch.Tensor):
+ """Check if 2D points are within the image boundaries."""
+ assert p2d.shape[-1] == 2
+ # assert p2d.shape[:-2] == self.shape # allow broadcasting
+ size = self.size.unsqueeze(-2)
+ valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
+ return valid
+
+ @autocast
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Project 3D points into the camera plane and check for visibility."""
+ z = p3d[..., -1]
+ valid = z > self.eps
+ z = z.clamp(min=self.eps)
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
+ return p2d, valid
+
+ def J_project(self, p3d: torch.Tensor):
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
+ zero = torch.zeros_like(z)
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
+ return J # N x 2 x 3
+
+ @autocast
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates
+ and check for validity of the distortion model.
+ """
+ assert pts.shape[-1] == 2
+ # assert pts.shape[:-2] == self.shape # allow broadcasting
+ return undistort_points(pts, self.dist)
+
+ @autocast
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
+
+ @autocast
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert pixel coordinates into normalized 2D coordinates."""
+ return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)
+
+ def J_denormalize(self):
+ return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2
+
+ @autocast
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Transform 3D points into 2D pixel coordinates."""
+ p2d, visible = self.project(p3d)
+ p2d, mask = self.undistort(p2d)
+ p2d = self.denormalize(p2d)
+ valid = visible & mask & self.in_image(p2d)
+ return p2d, valid
+
+ def J_world2image(self, p3d: torch.Tensor):
+ p2d_dist, valid = self.project(p3d)
+ J = self.J_denormalize() @ self.J_undistort(p2d_dist) @ self.J_project(p3d)
+ return J, valid
+
+ def __repr__(self):
+ return f"Camera {self.shape} {self.dtype} {self.device}"
diff --git a/mia/Dockerfile b/mia/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..44147fab853f33ab52929dc68d440a3706f64857
--- /dev/null
+++ b/mia/Dockerfile
@@ -0,0 +1,70 @@
+# Typical Usage:
+# docker image build . -t mia:release
+# docker run -v /home/:/home/ --network=bridge -it mia:release
+# Add '--gpus all' for gpu support
+
+# For CPU
+# FROM ubuntu:20.04
+# For GPU
+FROM nvidia/cuda:12.2.2-runtime-ubuntu20.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+# Install apt-getable dependencies
+RUN apt-get update \
+ && apt-get install -y \
+ build-essential \
+ cmake \
+ git \
+ libeigen3-dev \
+ libopencv-dev \
+ libceres-dev \
+ python3-dev \
+ curl \
+ pkg-config \
+ libcairo2-dev \
+ software-properties-common \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+
+# Mapmachine requirements
+RUN add-apt-repository ppa:ubuntugis/ppa && \
+ apt-get update && \
+ apt-get -y install libgeos-dev
+
+RUN add-apt-repository ppa:deadsnakes/ppa && \
+ apt-get update && \
+ apt install -y python3.9-dev && \
+ curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
+ python3.9 get-pip.py
+
+ARG REINSTALL_MAPMACHINE=1
+RUN pip3.9 install git+https://github.com/tonyzzzzzz/map-machine
+
+WORKDIR /home/
+
+# OrienterNet Requirements TODO: Install directly from our requirements once our repo is public
+
+RUN git clone https://github.com/mapillary/OpenSfM.git && cd OpenSfM && \
+ pip3.9 install -r requirements.txt
+
+RUN git clone https://github.com/facebookresearch/OrienterNet.git && cd OrienterNet && \
+ pip3 install -r requirements/full.txt
+
+# MapPerceptionNet extra requirements
+RUN pip3.9 install geojson shapely geopandas mercantile turfpy vt2geojson folium \
+ geopy gradio_client pyarrow cloudpickle==2.0.0 urllib3~=1.25.6 scikit-image filelock hydra-core
+
+# Earth Engine requirements (Required if sattelite image support is needed)
+
+RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
+ tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
+ curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
+ gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
+ apt-get update -y && apt-get install google-cloud-sdk -y
+
+RUN pip3.9 install earthengine-api
+
+# Run these once you are in the docker with your credentials and google earth project
+# earthengine authenticate
+# gcloud auth application-default set-quota-project PROJECT_ID
\ No newline at end of file
diff --git a/mia/__init__.py b/mia/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ed54d93bebb65cb4c1ebe64e3e14ef691ea6ad
--- /dev/null
+++ b/mia/__init__.py
@@ -0,0 +1,49 @@
+import os, sys
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+import logging
+from dataclasses import dataclass
+
+@dataclass
+class Colors:
+ grey: str = "\x1b[38;20m"
+ blue: str = "\x1b[34;20m"
+ bold_blue: str = "\x1b[34;1m"
+ yellow: str = "\x1b[33;20m"
+ red: str = "\x1b[31;20m"
+ bold_red: str = "\x1b[31;1m"
+ reset: str = "\x1b[0m"
+
+
+class ColorFormatter(logging.Formatter):
+
+ colors = Colors()
+ format = "[%(asctime)s %(name)s %(levelname)s] %(message)s"
+ datefmt="%Y-%m-%d %H:%M:%S"
+
+ FORMATS = {
+ logging.DEBUG: colors.grey + format + colors.reset,
+ logging.INFO: colors.grey + format + colors.reset,
+ logging.WARNING: colors.yellow + format + colors.reset,
+ logging.ERROR: colors.red + format + colors.reset,
+ logging.CRITICAL: colors.bold_red + format + colors.reset
+ }
+
+ def format(self, record):
+ log_fmt = self.FORMATS.get(record.levelno)
+ formatter = logging.Formatter(log_fmt, datefmt=self.datefmt)
+ return formatter.format(record)
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+handler = logging.StreamHandler()
+handler.setFormatter(ColorFormatter())
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger("mia")
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
\ No newline at end of file
diff --git a/mia/bev/get_bev.py b/mia/bev/get_bev.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d69693ad22a1f7f7b3fa485b40751954b9013b
--- /dev/null
+++ b/mia/bev/get_bev.py
@@ -0,0 +1,619 @@
+"""Script to get BEV images from a dataset of locations.
+
+Example usage:
+ python3.9 -m mia.bev.get_bev
+"""
+
+import argparse
+import multiprocessing as mp
+from pathlib import Path
+import io
+import os
+import requests
+import contextlib
+import traceback
+import colour
+
+import numpy as np
+from matplotlib import pyplot as plt
+import pandas as pd
+import geopandas as gpd
+import torch.nn as nn
+import torch
+from tqdm import tqdm
+from filelock import FileLock
+from math import sqrt, ceil
+import svgwrite
+import cairosvg
+from PIL import Image
+from xml.etree import ElementTree as ET
+from pyproj.transformer import Transformer
+from shapely.geometry import box
+from omegaconf import OmegaConf
+import urllib3
+
+from map_machine.map_configuration import MapConfiguration
+from map_machine.scheme import Scheme
+from map_machine.geometry.boundary_box import BoundaryBox
+from map_machine.osm.osm_getter import NetworkError
+from map_machine.osm.osm_reader import OSMData
+from map_machine.geometry.flinger import MercatorFlinger
+from map_machine.pictogram.icon import ShapeExtractor
+from map_machine.workspace import workspace
+from map_machine.mapper import Map
+from map_machine.constructor import Constructor
+
+from .. import logger
+from .image import center_crop_to_size, center_pad
+
+# MUST match colors from map rendering style
+COLORS = {
+ "road": "#000",
+ "crossing": "#F00",
+ "explicit_pedestrian": "#FF0",
+ "park": "#0F0",
+ "building": "#F0F",
+ "water": "#00F",
+ "terrain": "#0FF",
+ "parking": "#AAA",
+ "train": "#555"
+}
+
+# While the color mapping above must match what is in the
+# rendering style, the pretty colors below are just for visualization
+# purposes and can easily be changed below without worrying.
+# Colors set to None will not be rendered in rendered masks
+PRETTY_COLORS = {
+ "road": "#444",
+ "crossing": "#F4A261",
+ "explicit_pedestrian": "#E9C46A",
+ "park": None,
+ "building": "#E76F51",
+ "water": None,
+ "terrain": "#2A9D8F",
+ "parking": "#CCC",
+ "train": None
+}
+
+# Better order for visualization
+VIS_ORDER = ["terrain", "water", "park", "parking", "train",
+ "road", "explicit_pedestrian", "crossing", "building"]
+
+def checkColor(code):
+
+ def check_ele(ele):
+ isColor = False
+ if "stroke" in ele.attribs:
+ if ele.attribs["stroke"] != "none":
+ color = colour.Color(ele.attribs["stroke"])
+ isColor |= color == colour.Color(code)
+
+ if "fill" in ele.attribs:
+ if ele.attribs["fill"] != "none":
+ color = colour.Color(ele.attribs["fill"])
+ isColor |= color == colour.Color(code)
+
+ return isColor
+
+ return check_ele
+
+def hex2rgb(hex_str):
+ hex_str = hex_str.lstrip('#')
+ if len(hex_str) == 3:
+ hex_str = "".join([hex_str[i//2] for i in range(6)])
+ return tuple(int(hex_str[i:i+2], 16) for i in (0, 2, 4))
+
+def mask2rgb(mask, pretty=True):
+ H,W,N = mask.shape
+ rgb = np.ones((H,W,3), dtype=np.uint8)*255
+ cmap = PRETTY_COLORS if pretty else COLORS
+ key2mask_i = dict(zip(cmap.keys(), range(N)))
+ for k in VIS_ORDER:
+ if cmap[k]:
+ rgb[mask[:,:, key2mask_i[k]]>0.5] = (np.array(hex2rgb(cmap[k])))
+
+ return rgb
+
+def draw_bev(bbox: BoundaryBox, osm_data: OSMData,
+ configuration: MapConfiguration, meters_per_pixel: float, heading: float):
+ """Rasterize OSM data as a BEV image"""
+ lat = bbox.center()[0]
+ # Equation rearranged from https://wiki.openstreetmap.org/wiki/Zoom_levels
+ # To get zoom level given meters_per_pixel
+ z = np.log2(np.abs(osm_data.equator_length*np.cos(np.deg2rad(lat))/meters_per_pixel/256))
+ flinger = MercatorFlinger(bbox, z, osm_data.equator_length)
+
+ size = flinger.size
+ svg: svgwrite.Drawing = svgwrite.Drawing(None, size) # None since we are not saving an svg file
+
+ icon_extractor: ShapeExtractor = ShapeExtractor(
+ workspace.ICONS_PATH, workspace.ICONS_CONFIG_PATH
+ )
+ constructor: Constructor = Constructor(
+ osm_data=osm_data,
+ flinger=flinger,
+ extractor=icon_extractor,
+ configuration=configuration,
+ )
+ constructor.construct()
+ map_: Map = Map(flinger=flinger, svg=svg, configuration=configuration)
+ try:
+ imgs = []
+
+ map_.draw(constructor)
+
+ # svg.defs.add(svgwrite.container.Style(f"transform: rotate({str(heading)}deg)"))
+ for ele in svg.elements:
+ ele.rotate(360 - heading, (size[0]/2, size[1]/2))
+
+ for k, v in COLORS.items():
+ svg_new = svg.copy()
+ svg_new.elements = list(filter(checkColor(v), svg_new.elements))
+
+ png_byte_string = cairosvg.svg2png(bytestring=svg_new.tostring(),
+ output_width=size[0],
+ output_height=size[1]) # convert svg to png
+ img = Image.open(io.BytesIO(png_byte_string))
+
+ imgs.append(img)
+
+ except Exception as e:
+ # Prepare the stack trace
+ stack_trace = traceback.format_exc()
+ logger.error(f"Failed to render BEV for bbox {bbox.get_format()}. Exception: {repr(e)}. Skipping.. Stack trace: {stack_trace}")
+ return None, None
+
+ return imgs, svg
+
+
+def process_img(img, num_pixels, heading=None):
+ """Rotate + Crop to correct for heading and ensure correct dimensions"""
+
+ img = center_pad(img, num_pixels, num_pixels)
+ s = min(img.size)
+ squared_img = center_crop_to_size(img, s, s) # Ensure it is square before rotating (Perhaps not needed)
+ if heading:
+ squared_img = squared_img.rotate(heading, expand=False, resample=Image.Resampling.BILINEAR)
+ center_cropped_bev_img = center_crop_to_size(squared_img, num_pixels, num_pixels)
+ # robot_cropped_bev_img = center_cropped_bev_img.crop((0, 0, num_pixels, num_pixels/2)) # left, upper, right, lower
+ return center_cropped_bev_img
+
+
+def get_satellite_from_bbox(bbox, output_fp, num_pixels, heading):
+ # TODO: This method does not always produce a full satellite image.
+ # We need something more consistent like mapbox but free.
+
+ region = ee.Geometry.Rectangle(bbox, proj="EPSG:4326", geodesic=False)
+ # Load a satellite image collection, filter it by date and region, then select the first image
+ image = ee.ImageCollection('USDA/NAIP/DOQQ') \
+ .filterBounds(region) \
+ .filterDate('2022-01-01', '2022-12-31') \
+ .sort('CLOUDY_PIXEL_PERCENTAGE') \
+ .first().select(['R', 'G', 'B'])
+
+ # Reproject the image to a common projection (e.g., EPSG:4326)
+ image = image.reproject(crs='EPSG:4326', scale=0.5)
+
+ # Get the image URL
+ url = image.getThumbURL({'min': 0, 'max': 255, 'region': region.getInfo()['coordinates']})
+
+ # Download the image to your desktop
+ response = requests.get(url)
+ img = Image.open(io.BytesIO(response.content))
+ robot_cropped_bev_img = process_img(img, num_pixels, heading)
+ robot_cropped_bev_img.save(output_fp)
+
+
+def get_data(address: str, parameters: dict[str, str]) -> bytes:
+ """
+ Construct Internet page URL and get its descriptor.
+
+ :param address: URL without parameters
+ :param parameters: URL parameters
+ :return: connection descriptor
+ """
+ for _ in range(50):
+ http = urllib3.PoolManager()
+
+ urllib3.disable_warnings()
+
+ try:
+ result = http.request("GET", address, fields=parameters)
+ except urllib3.exceptions.MaxRetryError:
+ continue
+
+ if result.status == 200:
+ break
+ else:
+ print(result.data)
+ raise NetworkError(f"Cannot download data: {result.status} {result.reason}")
+
+ http.clear()
+ return result.data
+
+
+def get_osm_data(bbox: BoundaryBox, osm_output_fp: Path,
+ overwrite=False, use_lock=False) -> OSMData:
+ """
+ Get OSM data within bounding box from usingoverpass APIs and
+ write data to osm_output_fp.
+ """
+
+ OVERPASS_ENDPOINTS = [
+ "http://overpass-api.de/api/map",
+ "http://overpass.kumi.systems/api/map",
+ "http://maps.mail.ru/osm/tools/overpass/api/map"
+ ]
+
+ RETRIES = 10
+ osm_data = None
+ overpass_endpoints_i = 0
+
+ for retry in range(RETRIES):
+ try:
+ # fetch or load from cache
+ # A lock is needed if we are using multiple processes without store_osm_per_id
+ # Since multiple workers may share the same cached OSM file.
+ # Note: Can optimize locking further by implementing a readers-writer lock scheme
+ if use_lock:
+ lock_fp = osm_output_fp.parent.parent / (osm_output_fp.parent.name + "_tmp_locks") / (osm_output_fp.name + ".lock")
+ lock = FileLock(lock_fp)
+ else:
+ lock = contextlib.nullcontext()
+
+ with lock:
+ if not overwrite and osm_output_fp.is_file():
+ with osm_output_fp.open(encoding="utf-8") as output_file:
+ xml_str = output_file.read()
+ else:
+ content: bytes = get_data(
+ address=OVERPASS_ENDPOINTS[overpass_endpoints_i],
+ parameters={"bbox": bbox.get_format()}
+ )
+
+ xml_str = content.decode("utf-8")
+
+ if not content.startswith(b" None:
+ """Get BEV image from a boundary box. Optionally rotate, crop and save the extracted semantic mask."""
+
+ if osm_data is None:
+ if osm_output_fp.is_file():
+ # Load from cache
+ try:
+ osm_data = OSMData()
+ with osm_output_fp.open(encoding="utf-8") as output_file:
+ xml_str = output_file.read()
+ tree = ET.fromstring(xml_str)
+ osm_data.parse_osm(tree, parse_nodes=True,
+ parse_relations=True, parse_ways=True)
+ except Exception as e:
+ osm_data, _ = get_osm_data(bbox, osm_output_fp, use_lock=use_osm_cache_lock)
+ else:
+ # No local osm planet dump file. Need to download or read from cache
+ osm_data, _ = get_osm_data(bbox, osm_output_fp, use_lock=use_osm_cache_lock)
+
+ if osm_data is None:
+ return
+
+ if download_osm_only:
+ return
+
+ imgs, svg = draw_bev(bbox, osm_data, configuration, meters_per_pixel, heading)
+ if imgs is None:
+ return
+
+ if bev_output_fp:
+ svg.saveas(bev_output_fp)
+
+ cropped_imgs = []
+ for img in imgs:
+ # Set heading to None because we already rotated in draw_bev
+ cropped_imgs.append(process_img(img, num_pixels, heading=None))
+
+ masks = []
+ for img in cropped_imgs:
+ arr = np.array(img)
+ masks.append(arr[..., -1] != 0)
+
+ extracted_mask = np.stack(masks, axis=0)
+ extracted_mask[2][extracted_mask[0]] = 0
+
+ if final_downsample > 1:
+ max_pool_layer = nn.MaxPool2d(kernel_size=final_downsample, stride=final_downsample)
+ # Apply max pooling
+ mask_tensor = torch.tensor(extracted_mask, dtype=torch.float32).unsqueeze(0)
+ max_pool_tensor = max_pool_layer(mask_tensor)
+ # Remove the batch dimension and permute back to original dimension order, then convert to numpy
+ multilabel_mask_downsampled = max_pool_tensor.squeeze(0).permute(1, 2, 0).numpy()
+ else:
+ multilabel_mask_downsampled = extracted_mask.transpose(1, 2, 0)
+
+
+ # Save npz files for semantic masks
+ if mask_output_fp:
+ np.savez_compressed(mask_output_fp, multilabel_mask_downsampled)
+
+ # Save rendered BEV map if we want for visualization
+ if rendered_mask_output_fp:
+ rgb = mask2rgb(multilabel_mask_downsampled)
+ plt.imsave(rendered_mask_output_fp.with_suffix('.png'), rgb)
+
+
+def get_bev_from_bbox_worker_init(osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
+ scheme_path, map_length, meters_per_pixel,
+ osm_data, redownload, download_osm_only, store_osm_per_id,
+ use_osm_cache_lock, final_downsample):
+ global worker_kwargs
+ worker_kwargs=locals()
+ # MapConfiguration is not picklable so we have to initialize it for each worker
+ scheme = Scheme.from_file(Path(scheme_path))
+ configuration = MapConfiguration(scheme)
+ configuration.show_credit = False
+ worker_kwargs["configuration"] = configuration
+ logger.info(f"Worker {os.getpid()} started.")
+
+
+def get_bev_from_bbox_worker(job_dict):
+ id = job_dict['id']
+ bbox = job_dict['bbox_formatted']
+ bbox = BoundaryBox.from_text(bbox)
+ heading = job_dict['computed_compass_angle']
+
+ # Setting a path to None disables storing that file
+ bev_fp = worker_kwargs["bev_dir"]
+ if bev_fp:
+ bev_fp = bev_fp / f"{id}.svg"
+
+ semantic_mask_fp = worker_kwargs["semantic_mask_dir"]
+ if semantic_mask_fp:
+ semantic_mask_fp = semantic_mask_fp / f"{id}.npz"
+
+ rendered_mask_fp = worker_kwargs["rendered_mask_dir"]
+ if rendered_mask_fp:
+ rendered_mask_fp = rendered_mask_fp / f"{id}.png"
+
+ if worker_kwargs["store_osm_per_id"]:
+ osm_output_fp = worker_kwargs["osm_cache_dir"] / f"{id}.osm"
+ else:
+ osm_output_fp = worker_kwargs["osm_cache_dir"] / f"{bbox.get_format()}.osm"
+
+
+ if ( (bev_fp is None or bev_fp.exists() ) # Bev exists or we don't want to save it
+ and (semantic_mask_fp is None or semantic_mask_fp.exists()) # ...
+ and (rendered_mask_fp is None or rendered_mask_fp.exists()) # ...
+ and not worker_kwargs["redownload"]):
+ return
+
+ get_bev_from_bbox(bbox=bbox,
+ num_pixels=worker_kwargs["map_length"],
+ meters_per_pixel=worker_kwargs["meters_per_pixel"],
+ configuration=worker_kwargs["configuration"],
+ osm_output_fp=osm_output_fp,
+ bev_output_fp=bev_fp,
+ mask_output_fp=semantic_mask_fp,
+ rendered_mask_output_fp=rendered_mask_fp,
+ osm_data=worker_kwargs["osm_data"],
+ heading=heading,
+ final_downsample=worker_kwargs["final_downsample"],
+ download_osm_only=worker_kwargs["download_osm_only"],
+ use_osm_cache_lock=worker_kwargs["use_osm_cache_lock"])
+
+def main(dataset_dir, locations, args):
+ # setup directory paths
+ dataset_dir = Path(dataset_dir)
+
+ for loc in locations:
+ loc_name = loc["name"].lower().replace(" ", "_")
+ location_dir = dataset_dir / loc_name
+ osm_cache_dir = location_dir / "osm_cache"
+ bev_dir = location_dir / "bev_raw" if args.store_all_steps else None
+ semantic_mask_dir = location_dir / "semantic_masks"
+ rendered_mask_dir = location_dir / "rendered_semantic_masks" if args.store_all_steps else None
+
+ for d in [location_dir, osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir]:
+ if d:
+ d.mkdir(parents=True, exist_ok=True)
+
+ # read the parquet file
+ parquet_fp = location_dir / f"image_metadata_filtered_processed.parquet"
+ logger.info(f"Reading parquet file from {parquet_fp}.")
+ df = pd.read_parquet(parquet_fp)
+
+ if args.n_samples > 0:# If -1, use all samples
+ logger.info(f"Sampling {args.n_samples} rows.")
+ df = df.sample(args.n_samples, replace=False, random_state=1)
+
+ df.reset_index(drop=True, inplace=True)
+ logger.info(f"Read {len(df)} rows from the parquet file.")
+
+ # convert pandas dataframe to geopandas dataframe
+ gdf = gpd.GeoDataFrame(df,
+ geometry=gpd.points_from_xy(
+ df['computed_geometry.long'],
+ df['computed_geometry.lat']),
+ crs=4326)
+
+ # convert the geopandas dataframe to UTM
+ utm_crs = gdf.estimate_utm_crs()
+ gdf_utm = gdf.to_crs(utm_crs)
+ transformer = Transformer.from_crs(utm_crs, 4326)
+ logger.info(f"UTM zone for {loc_name} is {utm_crs.to_epsg()}.")
+
+ # load OSM data, if available
+ padding = args.padding
+ # calculate the required distance from the center to the edge of the image
+ # so that the image will not be out of bounds when we rotate it
+ map_length = args.map_length
+ map_length = ceil(sqrt(map_length**2 + map_length**2))
+ distance = map_length * args.meters_per_pixel / 2
+ logger.info(f"Each image will be {map_length:.2f} x {map_length:.2f} pixels. The distance from the center to the edge is {distance:.2f} meters.")
+
+ osm_data = None
+ if args.osm_fp:
+ logger.info(f"Loading OSM data from {args.osm_fp}.")
+ osm_fp = Path(args.osm_fp)
+ osm_data = OSMData()
+ if osm_fp.suffix == '.osm':
+ osm_data.parse_osm_file(osm_fp)
+ elif osm_fp.suffix == '.json':
+ osm_data.parse_overpass(osm_fp)
+ else:
+ raise ValueError(f"OSM file format {osm_fp.suffix} is not supported.")
+ # make sure that the loaded osm data at least covers some points in the dataframe
+ bbox = osm_data.boundary_box
+ shapely_bbox = box(bbox.left, bbox.bottom, bbox.right, bbox.top)
+ logger.warning(f"Clipping the geopandas dataframe to the OSM boundary box. May result in loss of points.")
+ gdf = gpd.clip(gdf, shapely_bbox)
+ if gdf.empty:
+ raise ValueError("Clipped geopandas dataframe is empty. Exiting.")
+ logger.info(f"Clipped geopandas dataframe is left with {len(gdf)} points.")
+
+ elif args.one_big_osm:
+ osm_fp = location_dir / "one_big_map.osm"
+ min_long = gdf_utm.geometry.x.min() - distance - padding
+ max_long = gdf_utm.geometry.x.max() + distance + padding
+ min_lat = gdf_utm.geometry.y.min() - distance - padding
+ max_lat = gdf_utm.geometry.y.max() + distance + padding
+ padding = 0
+ big_bbox = transformer.transform_bounds(left=min_long, bottom=min_lat, right=max_long, top=max_lat)
+ # TODO: Check why transformer is flipping lat long
+ big_bbox = (big_bbox[1], big_bbox[0], big_bbox[3], big_bbox[2])
+ big_bbox_fmt = ",".join([str(x) for x in big_bbox])
+ logger.info(f"Fetching one big osm file using coordinates {big_bbox_fmt}.")
+ big_bbox = BoundaryBox.from_text(big_bbox_fmt)
+ osm_data, retries = get_osm_data(big_bbox, osm_fp, overwrite=args.redownload)
+
+ # create bounding boxes for each point
+ gdf_utm['bounding_box_utm_p1'] = gdf_utm.apply(lambda row: (
+ row.geometry.x - distance - padding,
+ row.geometry.y - distance - padding,
+ ), axis=1)
+
+ gdf_utm['bounding_box_utm_p2'] = gdf_utm.apply(lambda row: (
+ row.geometry.x + distance + padding,
+ row.geometry.y + distance + padding,
+ ), axis=1)
+
+ # convert the bounding box back to lat, long
+ gdf_utm['bounding_box_lat_long_p1'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p1']), axis=1)
+ gdf_utm['bounding_box_lat_long_p2'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p2']), axis=1)
+ gdf_utm['bbox_min_lat'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[0])
+ gdf_utm['bbox_min_long'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[1])
+ gdf_utm['bbox_max_lat'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[0])
+ gdf_utm['bbox_max_long'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[1])
+ gdf_utm['bbox_formatted'] = gdf_utm.apply(lambda row: f"{row['bbox_min_long']},{row['bbox_min_lat']},{row['bbox_max_long']},{row['bbox_max_lat']}", axis=1)
+
+ # iterate over the dataframe and get BEV images
+ jobs = gdf_utm[['id', 'bbox_formatted', 'computed_compass_angle']] # only need the id and bbox_formatted columns for the jobs
+ jobs = jobs.to_dict(orient='records').copy()
+
+ use_osm_cache_lock = args.n_workers > 0 and not args.store_osm_per_id
+ if use_osm_cache_lock:
+ logger.info("Using osm cache locks to prevent race conditions since number of workers > 0 and store_osm_per_id is false")
+
+ init_args = [osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
+ args.map_machine_scheme,
+ args.map_length, args.meters_per_pixel,
+ osm_data, args.redownload, args.download_osm_only,
+ args.store_osm_per_id, use_osm_cache_lock, args.final_downsample]
+
+ if args.n_workers > 0:
+ logger.info(f"Launching {args.n_workers} workers to fetch BEVs for {len(jobs)} bounding boxes.")
+ with mp.Pool(args.n_workers,
+ initializer=get_bev_from_bbox_worker_init,
+ initargs=init_args) as pool:
+ for _ in tqdm(pool.imap_unordered(get_bev_from_bbox_worker, jobs, chunksize=16),
+ total=len(jobs), desc="Getting BEV images"):
+ pass
+ else:
+ get_bev_from_bbox_worker_init(*init_args)
+ pbar = tqdm(jobs, desc="Getting BEV images")
+ for job_dict in pbar:
+ get_bev_from_bbox_worker(job_dict)
+
+ # Download sattelite images if needed
+ if args.store_sat:
+ logger.info("Downloading sattelite images.")
+ sat_dir = location_dir / "sattelite"
+ sat_dir.mkdir(parents=True, exist_ok=True)
+ pbar = tqdm(jobs, desc="Getting Sattelite images")
+ for job_dict in pbar:
+ id = job_dict['id']
+ sat_fp = sat_dir / f"{id}.png"
+ if sat_fp.exists() and not args.redownload:
+ continue
+ bbox = [float(x) for x in job_dict['bbox_formatted'].split(",")]
+ try:
+ get_satellite_from_bbox(bbox, sat_fp, heading=job_dict['computed_compass_angle'], num_pixels=args.map_length)
+ except Exception as e:
+ logger.error(f"Failed to get sattelite image for bbox {job_dict['bbox_formatted']}. Exception {repr(e)}")
+
+ # TODO: Post BEV retireval filtering
+ # df.to_parquet(location_dir / "image_metadata_bev_processed.parquet")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Get BEV images from a dataset of locations using MapMachine.")
+ parser.add_argument("--cfg", type=str, default="mia/conf/example.yaml", help="Path to config yaml file.")
+ args = parser.parse_args()
+
+ cfgs = OmegaConf.load(args.cfg)
+
+ if cfgs.bev_options.store_sat:
+ if cfgs.bev_options.n_workers > 0:
+ logger.fatal("Satellite download is not multiprocessed yet !!")
+ import ee
+ ee.Initialize()
+
+ logger.info("="*80)
+ logger.info("Running get_bev.py")
+ logger.info("Arguments:")
+ for arg in vars(args):
+ logger.info(f"- {arg}: {getattr(args, arg)}")
+ logger.info("="*80)
+ main(cfgs.dataset_dir, cfgs.cities, cfgs.bev_options)
\ No newline at end of file
diff --git a/mia/bev/image.py b/mia/bev/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a7fefec66e27907a0804efbdd2c2d22a43d5c0f
--- /dev/null
+++ b/mia/bev/image.py
@@ -0,0 +1,26 @@
+from PIL import Image
+
+
+def center_pad(img: Image, width: int, height: int):
+ if img.width < width or img.height < height:
+ height = max(img.height, height)
+ width = max(img.width, width)
+ padded_img = Image.new("RGBA", (width, height), (0, 0, 0, 0))
+ x_offset = (width - img.width) // 2
+ y_offset = (height - img.height) // 2
+ padded_img.paste(img, (x_offset, y_offset))
+ img = padded_img
+ return img
+
+def center_crop_to_size(img: Image, width: int, height: int) -> Image:
+ """Center crop the image to the given width and height."""
+ if img.width < width:
+ raise ValueError("Invalid crop width. Crop width is larger than image width.")
+ if img.height < height:
+ raise ValueError("Invalid crop height. Crop height is larger than image height.")
+ left = (img.width - width) / 2
+ top = (img.height - height) / 2
+ right = (img.width + width) / 2
+ bottom = (img.height + height) / 2
+ img = img.crop((left, top, right, bottom))
+ return img
\ No newline at end of file
diff --git a/mia/bev/styles/default.yml b/mia/bev/styles/default.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8f10e632bb679c12a348bcc7eda908eae126a173
--- /dev/null
+++ b/mia/bev/styles/default.yml
@@ -0,0 +1,3037 @@
+options:
+
+ draw_nodes: yes
+ draw_trees: yes
+ draw_craters: yes
+ draw_buildings: yes
+ draw_directions: yes
+
+colors:
+
+ # Entity
+
+ default: "#444444"
+ extra: "#888888"
+
+ direction_view_color: "#C8E8FF"
+ direction_camera_color: "#0088FF"
+
+ background_color: "#EEEEEE"
+ road_color: "#FFFFFF"
+ text_color: "#444444"
+ text_main_color: "#000000"
+ text_outline_color: "#FFFFFF"
+
+ wheat_color: "#F0DCAA"
+ wheat_border_color: "#F4D67F"
+ wheat_dark_color: "#BF9340"
+ rye_color: "#E0CA96"
+ rye_dark_color: "#AE955D"
+ oat_color: "#EDDDB7"
+ oat_dark_color: "#C4894A"
+ barley_color: "#F3EEC4"
+ barley_border_color: "#D2CE9D"
+ barley_dark_color: "#908F62"
+ sunflower_dark_color: "#DEAC4A"
+
+ motorway_border_color: "#CC8800"
+ motorway_color: "#FFAA33"
+ primary_border_color: "#AA8800"
+ primary_color: "#FFDD66"
+ secondary_border_color: "#BB9911"
+ secondary_color: "#FFEE77"
+ tertiary_border_color: "#CCAA22"
+ tertiary_color: "#FFFF88"
+
+ bridge_color: "#666666"
+ ford_color: "#88BBFF"
+ embankment_color_color: "#666666"
+
+ allotments_color: "#D0E0D0"
+ beach_color: "#F0E0C0"
+ boundary_color: "#880088"
+ building_border_color: "#E0D0C0"
+ building_color: "#F8F0E8"
+ building_construction_border_color: "#C4C0BC"
+ building_construction_color: "#D4D0CC"
+ construction_color: "#CCCCCC"
+ cycle_color: "#4444EE"
+ desert_color: "#F0E0D0"
+ decidious_color: "#FCAF3E"
+ emergency_color: "#DD2222"
+ evergreen_color: "#688C44"
+ farmland_color: "#FFEEBB"
+ farmland_border_color: "#DDCC99"
+ farmland_darker_color: "#998855"
+ ferry_terminal_color: "#AABBDD"
+ foot_area_color: "#DDDDDD"
+ foot_area_border_color: "#BBBBBB"
+ foot_border_color: "#FFFFFF"
+ foot_color: "#B89A74"
+ grass_border_color: "#BFD098"
+ grass_color: "#CFE0A8"
+ hidden_color: "#000000"
+ indoor_border_color: "#A0A890"
+ indoor_color: "#E8E4E0"
+ indoor_column_color: {color: indoor_border_color, darken: 0.5}
+ meadow_border_color: "#BFD078"
+ meadow_color: "#CFE088"
+ orchard_color: "#B8DCA4"
+ orchard_border_color: "#98BC84"
+ outline_color: "#FFFFFF"
+ parking_color: "#DDCC99"
+ park_color: "#CFE0A8"
+ pitch_color: "#AADDCC"
+ pitch_border_color: "#88BBAA"
+ platform_border_color: "#AAAAAA"
+ platform_color: "#CCCCCC"
+ playground_border_color: "#FFAA88"
+ playground_color: "#FFDDCC"
+ ridge_color: "#000000"
+ road_border_color: "#CCCCCC"
+ rock_color: "#DDDDDD"
+ route_color: "#FFFFFF"
+ sand_color: "#E8E0C0"
+ scree_color: "#CCCCCC"
+ track_color: "#A88A64"
+ trunk_color: "#97612b"
+ tree_color: "#98AC64"
+ village_green_color: "#DDEEBB"
+ wall_bottom_1_color: "#AAAAAA"
+ wall_bottom_2_color: "#C3C3C3"
+ wall_color: "#E8E8E8"
+ wall_construction_color: "#84807C"
+ water_border_color: "#6688BB"
+ water_color: "#AACCFF"
+ wetland_color: "#BFE0D8"
+ wood_border_color: "#A8BC74"
+ wood_color: "#B8CC84"
+
+ runway_color: "#999399"
+ runway_border_color: {color: runway_color, darken: 0.25}
+ taxiway_color: "#AAA4AA"
+ taxiway_border_color: {color: taxiway_color, darken: 0.25}
+
+ sell_color: "#880088"
+ craft_color: "#008800"
+
+ # Colors not in W3C
+
+ rose: "#FF007F" # Wikipedia
+ slate_blue: "#6A5ACD" # W3C slateblue
+
+colors_dark: # dark
+
+ # Entity
+
+ default: "#CCCCCC"
+
+ direction_view_color: {color: "#C8E8FF", darken: 0.75}
+ direction_camera_color: "#0088FF"
+
+ background_color: "#222222"
+ road_color: "#000000"
+ text_color: "#DDDDDD"
+ text_main_color: "#ffffff"
+ text_outline_color: "#000000"
+
+ wheat_color: "#F0DCAA"
+ wheat_border_color: "#F4D67F"
+ wheat_dark_color: "#BF9340"
+ rye_color: "#E0CA96"
+ rye_dark_color: "#AE955D"
+ oat_color: "#EDDDB7"
+ oat_dark_color: "#C4894A"
+ barley_color: "#F3EEC4"
+ barley_border_color: "#D2CE9D"
+ barley_dark_color: "#908F62"
+ sunflower_dark_color: "#DEAC4A"
+
+ motorway_border_color: "#CC8800"
+ motorway_color: "#FFAA33"
+ primary_border_color: {color: "#AA8800", darken: 0.3}
+ primary_color: {color: "#FFDD66", darken: 0.75}
+ secondary_border_color: {color: "#BB9911", darken: 0.3}
+ secondary_color: {color: "#FFEE77", darken: 0.75}
+ tertiary_border_color: {color: "#CCAA22", darken: 0.3}
+ tertiary_color: {color: "#FFFF88", darken: 0.75}
+
+ allotments_color: "#D0E0D0"
+ beach_color: "#F0E0C0"
+ boundary_color: "#880088"
+ building_border_color: "#888888"
+ building_color: "#444444"
+ construction_color: {color: "#CCCCCC", darken: 0.75}
+ cycle_color: "#4444EE"
+ desert_color: "#F0E0D0"
+ decidious_color: "#FCAF3E"
+ emergency_color: "#DD2222"
+ evergreen_color: "#688C44"
+ farmland_color: "#FFEEBB"
+ farmland_border_color: "#DDCC99"
+ farmland_darker_color: "#998855"
+ ferry_terminal_color: "#AABBDD"
+ foot_area_color: "#222222"
+ foot_area_border_color: "#444444"
+ foot_border_color: "#000000"
+ foot_color: {color: "#B89A74", darken: 0.5}
+ grass_border_color: {color: "#BFD098", darken: 0.75}
+ grass_color: {color: "#CFE0A8", darken: 0.75}
+ hidden_color: "#FFFFFF"
+ indoor_border_color: "#C0B8B0"
+ indoor_color: "#E8E4E0"
+ meadow_border_color: "#BFD078"
+ meadow_color: "#CFE088"
+ orchard_color: "#B8DCA4"
+ orchard_border_color: "#98BC84"
+ outline_color: "#FFFFFF"
+ parking_color: {color: "#DDCC99", darken: 0.75}
+ park_color: {color: "#CFE0A8", darken: 0.75}
+ pitch_color: {color: "#AADDCC", darken: 0.75}
+ pitch_border_color: {color: "#88BBAA", darken: 0.75}
+ platform_border_color: "#AAAAAA"
+ platform_color: "#CCCCCC"
+ playground_border_color: {color: "#FFAA88", darken: 0.75}
+ playground_color: {color: "#FFDDCC", darken: 0.75}
+ ridge_color: "#000000"
+ road_border_color: "#444444"
+ rock_color: "#DDDDDD"
+ route_color: "#FFFFFF"
+ sand_color: "#E8E0C0"
+ scree_color: "#CCCCCC"
+ track_color: "#A88A64"
+ trunk_color: "#97612b"
+ tree_color: "#98AC64"
+ water_border_color: "#6688BB"
+ water_color: "#AACCFF"
+ wall_bottom_1_color: "#444444"
+ wall_bottom_2_color: "#222222"
+ wall_color_start: 0.0
+ wetland_color: "#BFE0D8"
+ wood_border_color: {color: "#A8BC74", darken: 0.75}
+ wood_color: {color: "#B8CC84", darken: 0.75}
+
+ runway_color: "#999399"
+ runway_border_color: {color: runway_color, darken: 0.25}
+ taxiway_color: "#AAA4AA"
+ taxiway_border_color: {color: taxiway_color, darken: 0.25}
+
+ sell_color: "#880088"
+
+ # Colors not in W3C
+
+ rose: "#FF007F" # Wikipedia
+ slate_blue: "#6A5ACD" # W3C slateblue
+
+carto_colors:
+ building_border_color: {color: building_color, darken: 0.15}
+ building_color: "#d9d0c9"
+ cemetery_color: "#aacbaf"
+ commercial_color: "#f2dad9"
+ commercial_border_color: "#d1b2b0"
+ grass_color: "#cdebb0"
+ industrial_color: "#ebdbe8"
+ industrial_border_color: "#d1b2b0"
+ military_color: "#f55"
+ park_color: "#c8facc"
+ residential_color: "#e0dfdf"
+ residential_border_color: "#b9b9b9"
+ wood_color: "#add19e"
+
+material_colors:
+
+ bronze: "#CD7F32"
+ concrete: "#AAAAAA"
+ glass: "#CCEEFF"
+
+node_icons:
+
+ - group: "No draw"
+ tags:
+ - tags: {type: multipolygon}
+ draw: false
+ - tags: {place: "*"}
+ draw: false
+ - tags: {building: "yes"}
+ draw: false
+
+ - group: "Huge transport hubs"
+ start_zoom_level: 10.0
+ tags:
+ - tags: {amenity: ferry_terminal}
+ shapes: [anchor]
+ - tags: {amenity: ferry_terminal, cargo: vehicle}
+ shapes: [car_on_ferry]
+ - tags: {amenity: ferry_terminal, cargo: passengers}
+ shapes: [human_on_ferry]
+ - tags: {aeroway: aerodrome}
+ shapes: [plane]
+ - tags: {aeroway: helipad}
+ shapes: [h]
+ - tags: {aeroway: spaceport}
+ shapes: [rocket_on_launch_pad]
+
+ - group: "Normal transport hubs"
+ start_zoom_level: 11.0
+ tags:
+ - tags: {aeroway: launchpad}
+ shapes: [rocket_flying]
+ - tags: {aeroway: landingpad}
+ shapes: [booster_landing]
+ - tags: {highway: bus_station}
+ shapes: [buses]
+ - tags: {highway: bus_stop}
+ shapes: [bus_stop]
+ add_shapes: [bus]
+ - tags: {railway: station}
+ shapes: [train]
+ - tags: {railway: station, station: subway, transport: subway}
+ shapes: [train]
+ - tags: {railway: subway_entrance}
+ shapes: [train]
+ - tags: {railway: subway_entrance, entrance: "yes"}
+ shapes: [train]
+ - tags: {public_transport: stop_position}
+ shapes: [bus_stop]
+ - tags: {railway: tram_station}
+ shapes: [tram]
+ - tags: {railway: tram_stop}
+ shapes: [tram]
+ - tags: {public_transport: platform}
+ shapes: [bus_stop_sign]
+ with_icon: [bus_stop_bench, bus_stop_shelter]
+ over_icon: [platform]
+ - tags: {highway: bus_stop, public_transport: platform}
+ shapes: [bus_stop_sign]
+ with_icon: [bus_stop_bench, bus_stop_shelter]
+ over_icon: [platform]
+ - tags: {highway: bus_stop, shelter: "yes"}
+ shapes: [bus_stop_sign]
+ under_icon: [bus_stop_sign]
+ with_icon: [bus_stop_bench, platform]
+ over_icon: [bus_stop_shelter]
+ - tags: {highway: bus_stop, bench: "yes"}
+ under_icon: [bus_stop_sign]
+ with_icon: [bus_stop_shelter, platform]
+ over_icon: [bus_stop_bench]
+ - tags: {highway: stop}
+ shapes: [stop]
+ - tags: {amenity: taxi}
+ shapes: [taxi]
+
+ - group: "Big territory"
+ start_zoom_level: 12.0
+ tags:
+ - tags: {leisure: fishing}
+ shapes: [fishing_angle]
+ - tags: {historic: archaeological_site}
+ shapes: [amphora]
+ - tags: {leisure: swimming_area}
+ shapes: [swimming_area]
+ - tags: {leisure: swimming_pool}
+ shapes: [swimming_area]
+ - tags: {sport: swimming}
+ shapes: [swimming_area]
+ - tags: {leisure: beach}
+ shapes: [beach]
+ - tags: {amenity: public_bath}
+ shapes: [swimming_area]
+ - tags: {leisure: golf_course}
+ shapes: [golf_club_and_ball]
+ - tags: {power: substation}
+ shapes: [electricity]
+ - tags: {plant: christmas_trees}
+ shapes: [{shape: christmas_tree, color: orchard_border_color}]
+ - tags: {crop: sunflower}
+ shapes: [{shape: sunflower, color: sunflower_dark_color, outline: no}]
+ - tags: {crop: barley}
+ shapes: [{shape: ear_botany, color: barley_dark_color, outline: no}]
+ - tags: {crop: rye}
+ shapes: [{shape: ear_botany, color: rye_dark_color, outline: no}]
+ - tags: {crop: wheat}
+ shapes: [{shape: ear_botany, color: wheat_dark_color, outline: no}]
+ - tags: {crop: rape}
+ shapes: [{shape: rape, color: wheat_dark_color, outline: no}]
+ - tags: {produce: apple}
+ shapes: [{shape: apple, color: orchard_border_color}]
+ - tags: {produce: christmas_trees}
+ shapes: [{shape: christmas_tree, color: orchard_border_color}]
+ - tags: {produce: pear}
+ shapes: [{shape: pear, color: orchard_border_color}]
+ - tags: {trees: apple_trees}
+ shapes: [{shape: apple, color: orchard_border_color}]
+ - tags: {trees: pear_trees}
+ shapes: [{shape: pear, color: orchard_border_color}]
+
+ - group: "Bigger objects"
+ start_zoom_level: 13.0
+ tags:
+ - tags: {waterway: waterfall}
+ shapes: [{shape: waterfall, color: water_border_color}]
+ - tags: {natural: cliff}
+ shapes: [cliff]
+ - tags: {natural: peak}
+ shapes: [triangle_small]
+ - tags: {natural: saddle}
+ shapes: [saddle]
+ - tags: {natural: crater}
+ exception: {diameter: "*"}
+ shapes: [crater]
+
+ - tags: {natural: volcano}
+ shapes: [stratovolcano, {shape: smoke_2, offset: [0, -3]}]
+
+ - tags: {natural: volcano, volcano:type: stratovolcano}
+ shapes: [stratovolcano, {shape: smoke_2, offset: [0, -3]}]
+ - tags: {natural: volcano, volcano:type: shield}
+ shapes: [shield_volcano, {shape: smoke_2, offset: [0, -1]}]
+ - tags: {natural: volcano, volcano:type: scoria}
+ shapes: [volcanic_cone, {shape: smoke_2, offset: [0, -2]}]
+
+ - tags: {natural: volcano, volcano:status: active}
+ shapes: [stratovolcano, {shape: lava, offset: [0, -3]}]
+ - tags: {natural: volcano, volcano:status: dormant}
+ shapes: [stratovolcano, {shape: smoke, offset: [1, -3]}]
+ - tags: {natural: volcano, volcano:status: extinct}
+ shapes: [stratovolcano]
+
+ - tags:
+ natural: volcano
+ volcano:type: stratovolcano
+ volcano:status: active
+ shapes: [stratovolcano, {shape: lava, offset: [0, -3]}]
+ - tags:
+ natural: volcano
+ volcano:type: shield
+ volcano:status: active
+ shapes: [shield_volcano, {shape: lava, offset: [0, -1]}]
+ - tags:
+ natural: volcano
+ volcano:type: scoria
+ volcano:status: active
+ shapes: [volcanic_cone, {shape: lava, offset: [0, -2]}]
+ - tags:
+ natural: volcano
+ volcano:type: stratovolcano
+ volcano:status: dormant
+ shapes: [stratovolcano, {shape: smoke, offset: [1, -3]}]
+ - tags:
+ natural: volcano
+ volcano:type: shield
+ volcano:status: dormant
+ shapes: [shield_volcano, {shape: smoke, offset: [1, -1]}]
+ - tags:
+ natural: volcano
+ volcano:type: scoria
+ volcano:status: dormant
+ shapes: [volcanic_cone, {shape: smoke, offset: [1, -2]}]
+ - tags:
+ natural: volcano
+ volcano:type: stratovolcano
+ volcano:status: extinct
+ shapes: [stratovolcano]
+ - tags:
+ natural: volcano
+ volcano:type: shield
+ volcano:status: extinct
+ shapes: [shield_volcano]
+ - tags:
+ natural: volcano
+ volcano:type: scoria
+ volcano:status: extinct
+ shapes: [volcanic_cone]
+
+ - tags: {historic: castle}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_castle]
+ - tags: {historic: fort}
+ shapes: [fort]
+ - tags: {shop: mall}
+ shapes: [bag]
+ - tags: {shop: department_store}
+ shapes: [bag]
+ - tags: {shop: mall, building: "yes"}
+ shapes: [bag]
+ - tags: {leisure: water_park}
+ shapes: [slide_and_water]
+
+ - group: "Important big objects"
+ start_zoom_level: 14.0
+ tags:
+ - tags: {amenity: fire_station}
+ location_restrictions: {include: [jp]}
+ shapes: [{shape: japan_fire_station, color: emergency_color}]
+ - tags: {amenity: courthouse}
+ shapes: [gavel]
+ - tags: {amenity: police}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_police_station]
+ - tags: {amenity: police, name:en: Koban}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_koban]
+ - tags: {building: townhall}
+ shapes: [townhall]
+ - tags: {amenity: townhall}
+ shapes: [townhall]
+ - tags: {historic: city_gate}
+ shapes: [city_gate]
+ - tags: {amenity: pharmacy}
+ shapes: [medicine_bottle]
+ - tags: {amenity: embassy}
+ shapes: [waving_flag]
+ - tags: {office: diplomatic, diplomatic: embassy}
+ shapes: [waving_flag]
+ - tags: {man_made: monitoring_station, monitoring:weather: "yes"}
+ shapes: [japan_weather_station]
+ location_restrictions: {include: [jp]}
+ - tags: {amenity: veterinary}
+ shapes: [dog_and_cross]
+ - tags: {tourism: apartment}
+ shapes: [bed_with_floor_and_ceiling]
+ - tags: {tourism: hotel}
+ shapes: [bed]
+ - tags: {building: hotel}
+ shapes: [bed]
+ - tags: {tourism: hostel}
+ shapes: [two_beds]
+ - tags: {tourism: motel}
+ shapes: [{shape: car, offset: [0, 4]}, {shape: bed, offset: [0, -2]}]
+ - tags: {tourism: guest_house}
+ shapes: [bed_and_roof]
+ - tags: {amenity: hospital}
+ location_restrictions: {include: world, exclude: [jp]}
+ shapes: [greek_cross]
+ - tags: {amenity: hospital}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_public_health_center]
+ - tags: {amenity: clinic}
+ shapes: [greek_cross_in_box]
+ - tags: {amenity: doctors}
+ shapes: [greek_cross_in_box]
+ - tags: {amenity: dentist}
+ shapes: [tooth]
+ - tags: {amenity: post_office}
+ location_restrictions: {include: world, exclude: [jp]}
+ shapes: [envelope]
+ - tags: {amenity: post_office}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_post]
+ - tags: {shop: car_repair}
+ shapes: [{shape: car, offset: [0, 3]}, {shape: wrench, offset: [0, -4]}]
+ - tags: {amenity: car_rental}
+ shapes: [{shape: car, offset: [0, 3]}, {shape: key, offset: [1, -3]}]
+ - tags: {amenity: car_sharing}
+ shapes: [{shape: car, offset: [0, 3]}, {shape: sharing, offset: [0, -4]}]
+ - tags: {amenity: car_wash}
+ shapes:
+ - {shape: car, offset: [0, 3]}
+ - {shape: shower_head, offset: [0, -4]}
+ # Place of worship
+ - tags: {building: shrine, religion: shinto}
+ shapes: [japan_shinto_shrine]
+ - tags: {religion: christian}
+ shapes: [latin_cross]
+ - tags: {amenity: place_of_worship, religion: christian}
+ shapes: [latin_cross]
+ - tags:
+ amenity: place_of_worship
+ religion: christian
+ denomination: catholic
+ shapes: [latin_cross]
+ - tags:
+ amenity: place_of_worship
+ religion: christian
+ denomination: russian_orthodox
+ shapes: [russian_orthodox_cross]
+ - tags:
+ amenity: place_of_worship
+ religion: christian
+ denomination: orthodox
+ shapes: [orthodox]
+ - tags:
+ amenity: place_of_worship
+ religion: christian
+ denomination: baptist
+ shapes: [baptist]
+ - tags: {amenity: place_of_worship, religion: muslim}
+ shapes: [crescent]
+ - tags: {amenity: place_of_worship, religion: buddhist}
+ shapes: [dharmachakra]
+ - tags: {amenity: place_of_worship, religion: jewish}
+ shapes: [star_of_david]
+ - tags: {historic: tomb, tomb: mausoleum}
+ shapes: [mausoleum]
+ - tags: {historic: tomb, tomb: pyramid}
+ shapes: [pyramid]
+ - tags: {historic: "*"}
+ shapes: [japan_historic]
+ replace_shapes: no
+ location_restrictions: {include: [jp]}
+
+ - group: "Normal big objects"
+ start_zoom_level: 15.0
+ tags:
+ - tags: {shop: supermarket}
+ shapes: [supermarket_cart]
+ - tags: {shop: variety_store}
+ shapes: [bag_with_percent]
+ - tags: {shop: general}
+ shapes: [bag]
+ - tags: {amenity: arts_centre}
+ shapes: [picture]
+ - tags: {amenity: bank}
+ shapes: [money]
+ - tags: {amenity: cinema}
+ shapes: [film]
+ - tags: {amenity: casino}
+ shapes: [card_and_dice]
+ - tags: {amenity: community_centre}
+ shapes: [two_people_together]
+ - tags: {amenity: gym}
+ shapes: [dumbbell]
+ - tags: {amenity: social_facility}
+ shapes: [two_people_together]
+ - tags: {amenity: internet_cafe}
+ shapes: [at_in_square]
+ - tags: {amenity: library}
+ shapes: [book]
+ - tags: {amenity: marketplace}
+ shapes: [marketplace]
+ - tags: {amenity: prison}
+ shapes: [prison]
+ - tags: {amenity: stripclub}
+ shapes: [pole_dancer]
+ - tags: {club: computer}
+ shapes: [glider]
+ - tags: {leisure: hackerspace}
+ shapes: [glider]
+ - tags: {man_made: survey_point}
+ shapes: [survey_point]
+ - tags: {leisure: amusement_arcade}
+ shapes: [pac_man]
+ - tags: {leisure: fitness_centre}
+ shapes: [dumbbell]
+ - tags: {leisure: fitness_station}
+ shapes: [dumbbell]
+ - tags: {leisure: bird_hide}
+ shapes: [binoculars]
+ - tags: {leisure: bleachers}
+ shapes: [bleachers]
+ - tags: {leisure: bowling_alley}
+ shapes: [bowling_ball]
+ - tags: {leisure: dog_park}
+ shapes: [dog]
+ - tags: {leisure: escape_game}
+ shapes: [{shape: maze, offset: [-2, 0]}, {shape: arrow_right_short, offset: [5, -1]}]
+ - tags: {leisure: maze}
+ shapes: [maze]
+ - tags: {attraction: maze}
+ shapes: [maze]
+ - tags: {maze: labyrinth}
+ shapes: [maze]
+ - tags: {tourism: maze}
+ shapes: [maze]
+ - tags: {leisure: miniature_golf}
+ shapes: [golf_club_and_ball]
+ - tags: {leisure: sauna}
+ shapes: [sauna]
+ - tags: {leisure: outdoor_seating}
+ shapes: [table_and_two_chairs, umbrella]
+ - tags: {leisure: outdoor_seating, weather_protection: parasol}
+ shapes: [table_and_two_chairs, umbrella]
+ - tags: {leisure: outdoor_seating, weather_protection: roof}
+ shapes: [table_and_two_chairs, roof]
+ - tags: {leisure: outdoor_seating, weather_protection: awning}
+ shapes: [table_and_two_chairs, awning]
+ - tags: {leisure: outdoor_seating, weather_protection: pavilion}
+ shapes: [table_and_two_chairs, roof_and_walls]
+ - tags: {leisure: outdoor_seating, weather_protection: pergola}
+ shapes: [table_and_two_chairs, pergola]
+ - tags: {leisure: playground}
+ shapes: [toy_horse]
+ - tags: {amenity: theatre}
+ shapes: [curtains]
+ - tags: {amenity: bar}
+ shapes: [cocktail_glass]
+ - tags: {amenity: pub}
+ shapes: [beer_mug]
+ - tags: {amenity: fast_food}
+ shapes: [burger]
+ - tags: {amenity: fast_food, cuisine: steak_house}
+ shapes: [steak_and_fork]
+ - tags: {amenity: food_court}
+ shapes: [food_court]
+ - tags: {craft: shoemaker}
+ shapes: [{shape: shoe, color: sell_color}]
+ - tags: {shop: fishing}
+ shapes: [{shape: fishing_angle, color: sell_color}]
+ - tags: {shop: alcohol}
+ shapes: [{shape: bottle, color: sell_color}]
+ - tags: {shop: antiques}
+ shapes: [{shape: amphora, color: sell_color}]
+ - tags: {shop: art}
+ shapes: [{shape: picture, color: sell_color}]
+ - tags: {shop: bakery}
+ shapes: [cupcake]
+ - tags: {shop: bag}
+ shapes: [{shape: bag, color: sell_color}]
+ - tags: {shop: bed}
+ shapes: [{shape: bed, color: sell_color}]
+ - tags: {shop: beauty}
+ shapes: [vanity_mirror]
+ - tags: {shop: cosmetics}
+ shapes: [vanity_mirror]
+ - tags: {shop: bicycle}
+ shapes: [{shape: bicycle, color: sell_color}]
+ - tags: {shop: books}
+ shapes: [{shape: book, color: sell_color}]
+ - tags: {shop: butcher}
+ shapes: [knives]
+ - tags: {shop: car}
+ shapes: [{shape: car, color: sell_color}]
+ - tags: {shop: car_parts}
+ shapes: [shape: engine]
+ - tags: {shop: chocolate}
+ shapes: [cupcake]
+ - tags: {shop: coffee}
+ shapes: [{shape: coffee_cup, color: sell_color}]
+ - tags: {shop: confectionery}
+ shapes: [cupcake]
+ - tags: {shop: copyshop}
+ shapes: [sheets]
+ - tags: {shop: dairy}
+ shapes: [aseptic_carton]
+ - tags: {shop: doityourself}
+ shapes: [wretch_and_hammer]
+ - tags: {shop: dry_cleaning}
+ shapes: [washing_machine]
+ - tags: {shop: farm}
+ shapes: [{shape: apple, color: sell_color}]
+ - tags: {shop: fireplace}
+ shapes: [{shape: fireplace, color: sell_color}]
+ - tags: {shop: florist}
+ shapes: [{shape: flower_in_pot, color: sell_color}]
+ - tags: {shop: furniture}
+ shapes: [{shape: drawer, color: sell_color}]
+ - tags: {shop: greengrocer}
+ shapes: [{shape: apple, color: sell_color}]
+ - tags: {shop: hairdresser}
+ shapes: [comb_and_scissors]
+ - tags: {shop: hardware}
+ shapes: [wretch_and_hammer]
+ - tags: {shop: hifi}
+ shapes: [{shape: hi_fi, color: sell_color}]
+ - tags: {shop: houseware}
+ shapes: [{shape: pan, color: sell_color}]
+ - tags: {shop: jewelry}
+ shapes: [{shape: diamond, color: sell_color}]
+ - tags: {shop: jewellery}
+ shapes: [{shape: diamond, color: sell_color}]
+ - tags: {craft: jeweller}
+ shapes: [{shape: diamond, color: craft_color}]
+ - tags: {shop: laundry}
+ shapes: [washing_machine]
+ - tags: {shop: massage}
+ shapes: [massage]
+ - tags: {shop: medical_supply}
+ shapes: [{shape: medicine_bottle, color; sell_color}]
+ - tags: {shop: mobile_phone}
+ shapes: [{shape: phone, color: sell_color}]
+ - tags: {shop: newsagent}
+ shapes: [gazette]
+ - tags: {shop: optician}
+ shapes: [glasses]
+ - tags: {shop: pastry}
+ shapes: [cupcake]
+ - tags: {shop: pet}
+ shapes: [{shape: dog, color: sell_color}]
+ - tags: {shop: photo}
+ shapes: [{shape: photo_camera, color: sell_color}]
+ - tags: {shop: photography}
+ shapes: [{shape: photo_camera, color: sell_color}]
+ - tags: {shop: shoes}
+ shapes: [{shape: shoe, color: sell_color}]
+ - tags: {shop: sports}
+ shapes: [{shape: dumbbell, color: sell_color}]
+ - tags: {shop: travel_agency}
+ shapes: [globe]
+ - tags: {shop: milk}
+ shapes: [aseptic_carton]
+ - tags: {shop: wine}
+ shapes: [{shape: bottle_and_wine_glass, color: sell_color}]
+ - tags: {building: store}
+ shapes: [shop_convenience]
+ - tags: {shop: ticket}
+ shapes: [ticket]
+ - tags: {shop: tailor}
+ shapes: [t_shirt_and_scissors]
+ - tags: {shop: tyres}
+ shapes: [{shape: tyre, color: sell_color}]
+ - tags: {shop: toys}
+ shapes: [{shape: toy_horse, color: sell_color}]
+ - tags: {craft: tailor}
+ shapes: [t_shirt_and_scissors]
+ - tags: {shop: video}
+ shapes: [{shape: film, color: sell_color}]
+ - tags: {shop: video_games}
+ shapes: [{shape: pac_man, color: sell_color}]
+ - tags: {shop: watches}
+ shapes: [{shape: watches, color: sell_color}]
+ - tags: {craft: watchmaker}
+ shapes: [watches]
+ - tags: {shop: frame}
+ shapes: [{shape: frame, color: sell_color}]
+ - tags: {tourism: gallery}
+ shapes: [picture]
+ - tags: {amenity: cafe}
+ shapes: [coffee_cup]
+ - tags: {amenity: ice_cream}
+ shapes: [ice_cream]
+ - tags: {amenity: biergarten}
+ shapes: [beer_mug]
+ - tags: {amenity: nightclub}
+ shapes: [cocktail_glass_with_straw]
+ - tags: {amenity: restaurant}
+ shapes: [fork_and_knife]
+ - tags: {amenity: restaurant;bar}
+ shapes: [fork_and_knife]
+ add_shapes: [cocktail_glass]
+ - tags: {shop: ice_cream}
+ shapes: [ice_cream]
+ - tags: {shop: gift}
+ shapes: [gift]
+ - tags: {shop: clothes}
+ shapes: [{shape: t_shirt, color: sell_color}]
+ - tags: {amenity: shop, shop: clothes}
+ shapes: [t_shirt]
+ - tags: {shop: convenience}
+ shapes: [shop_convenience]
+ - tags: {amenity: shop, shop: convenience}
+ shapes: [shop_convenience]
+ - tags: {shop: electronics}
+ shapes: [{shape: tv, color: sell_color}]
+ - tags: {tourism: camp_site}
+ shapes: [camp]
+ - tags: {tourism: caravan_site}
+ shapes: [caravan]
+ - tags: {leisure: picnic_site}
+ shapes: [table]
+
+ - group: "Big objects not for all"
+ start_zoom_level: 15.0
+ tags:
+ - tags: {building: container}
+ shapes: [building_container]
+ - tags: {building: houseboat}
+ shapes: [houseboat]
+
+ - tags: {building: apartments}
+ shapes: [apartments_2_story]
+
+ - tags: {building: "*", building:levels: "1"}
+ shapes: [apartments_1_story]
+ - tags: {building: "*", building:levels: "1", roof:shape: gabled}
+ shapes: [apartments_1_story_gabled_roof]
+ - tags: {building: "*", building:levels: "1", roof:shape: hipped}
+ shapes: [apartments_1_story_gabled_roof]
+ - tags: {building: "*", building:levels: "1", roof:shape: pyramidal}
+ shapes: [apartments_1_story_gabled_roof]
+ - tags: {building: "*", building:levels: "1", roof:shape: skillion}
+ shapes: [apartments_1_story_skillion_roof]
+
+ - tags: {building: "*", building:levels: "2"}
+ shapes: [apartments_2_story]
+ - tags: {building: "*", building:levels: "2", roof:shape: gabled}
+ shapes: [apartments_2_story_gabled_roof]
+ - tags: {building: "*", building:levels: "2", roof:shape: hipped}
+ shapes: [apartments_2_story_gabled_roof]
+ - tags: {building: "*", building:levels: "2", roof:shape: pyramidal}
+ shapes: [apartments_2_story_gabled_roof]
+ - tags: {building: "*", building:levels: "2", roof:shape: skillion}
+ shapes: [apartments_2_story_skillion_roof]
+
+ - tags: {building: "*", building:levels: "3"}
+ shapes: [apartments_3_story]
+ - tags: {building: "*", building:levels: "3", roof:shape: gabled}
+ shapes: [apartments_3_story_gabled_roof]
+ - tags: {building: "*", building:levels: "3", roof:shape: hipped}
+ shapes: [apartments_3_story_gabled_roof]
+ - tags: {building: "*", building:levels: "3", roof:shape: pyramidal}
+ shapes: [apartments_3_story_gabled_roof]
+ - tags: {building: "*", building:levels: "3", roof:shape: skillion}
+ shapes: [apartments_3_story_skillion_roof]
+
+ - tags: {building: "*", building:levels: "4"}
+ shapes: [apartments_4_story]
+ - tags: {building: "*", building:levels: "4", roof:shape: gabled}
+ shapes: [apartments_4_story_gabled_roof]
+ - tags: {building: "*", building:levels: "4", roof:shape: hipped}
+ shapes: [apartments_4_story_gabled_roof]
+ - tags: {building: "*", building:levels: "4", roof:shape: pyramidal}
+ shapes: [apartments_4_story_gabled_roof]
+ - tags: {building: "*", building:levels: "4", roof:shape: skillion}
+ shapes: [apartments_4_story_skillion_roof]
+
+ - tags: {building: "*", building:levels: "5"}
+ shapes: [apartments_5_story]
+ - tags: {building: "*", building:levels: "5", roof:shape: gabled}
+ shapes: [apartments_5_story_gabled_roof]
+ - tags: {building: "*", building:levels: "5", roof:shape: hipped}
+ shapes: [apartments_5_story_gabled_roof]
+ - tags: {building: "*", building:levels: "5", roof:shape: pyramidal}
+ shapes: [apartments_5_story_gabled_roof]
+ - tags: {building: "*", building:levels: "5", roof:shape: skillion}
+ shapes: [apartments_5_story_skillion_roof]
+
+ - tags: {building: construction}
+ shapes: [building_construction]
+ - tags: {building: apartments, construction: "yes"}
+ shapes: [building_construction]
+ - tags: {building: "yes", construction: "yes"}
+ shapes: [building_construction]
+
+ - tags: {building: kindergarten}
+ shapes: [toy_horse]
+ - tags: {amenity: kindergarten}
+ shapes: [toy_horse]
+ - tags: {building: kindergarten, amenity: kindergarten}
+ shapes: [toy_horse]
+ - tags: {leisure: indoor_playground}
+ shapes: [toy_horse]
+ - tags: {building: office}
+ shapes: [briefcase]
+ - tags: {amenity: school}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_elementary_school]
+ - tags: {office: "yes"}
+ shapes: [briefcase]
+ - tags: {office: company}
+ shapes: [briefcase]
+ - tags: {office: government}
+ shapes: [government]
+ - tags: {office: it}
+ shapes: [glider]
+ - tags: {office: telecommunication}
+ shapes: [telephone]
+
+ - group: "Not important big objects"
+ start_zoom_level: 15.0
+ tags:
+ - tags: {building: garage}
+ shapes: [garages]
+ - tags: {building: garages}
+ shapes: [garages]
+ - tags: {landuse: garages}
+ shapes: [garages]
+ - tags: {man_made: communications_tower}
+ location_restrictions: {include: [jp]}
+ shapes: [japan_tv_tower]
+ - tags: {man_made: communications_tower}
+ shapes: [tower_communication]
+ - tags: {man_made: telescope}
+ shapes: [telescope_radio]
+ - tags: {man_made: telescope, telescope:type: radio}
+ shapes: [telescope_radio]
+ - tags: {man_made: telescope, telescope:type: gamma}
+ shapes: [telescope_gamma]
+ - tags: {man_made: telescope, telescope:type: optical}
+ shapes: [observatory]
+ - tags: {man_made: tower}
+ shapes: [tower]
+ - tags: {man_made: tower, tower:construction: dish}
+ shapes: [telescope_radio]
+ - tags: {man_made: tower, tower:construction: dish, telescope:type: radio}
+ shapes: [telescope_radio]
+ - tags: {man_made: tower, tower:construction: dish, telescope:type: gamma}
+ shapes: [telescope_gamma]
+ - tags: {man_made: crane}
+ shapes: [crane]
+ - tags: {man_made: crane, crane:type: gantry_crane}
+ shapes: [crane_gantry]
+ - tags: {man_made: crane, crane:type: floor-mounted_crane}
+ shapes: [crane]
+ - tags: {man_made: crane, crane:type: portal_crane}
+ shapes: [crane_portal]
+ - tags: {man_made: crane, crane:type: travel_lift}
+ shapes: [crane_travel_lift]
+ - tags: {man_made: crane, crane:type: tower_crane}
+ shapes: [crane]
+
+ - group: "Emergency"
+ start_zoom_level: 15.0
+ tags:
+ - tags: {emergency: defibrillator}
+ shapes: [{shape: defibrillator, color: emergency_color}]
+ - tags: {emergency: fire_extinguisher}
+ shapes: [{shape: fire_extinguisher, color: emergency_color}]
+ - tags: {emergency: fire_hydrant}
+ shapes: [fire_hydrant]
+ - tags: {emergency: life_ring}
+ shapes: [{shape: life_ring, color: emergency_color}]
+ - tags: {emergency: phone}
+ shapes: [{shape: sos_phone, color: emergency_color}]
+
+ - group: "Transport-important middle objects"
+ start_zoom_level: 16.0
+ tags:
+ - tags: {ford: "yes"}
+ shapes: [ford]
+ - tags: {amenity: charging_station}
+ shapes: [charging_station]
+ - tags: {amenity: bicycle_repair_station}
+ shapes:
+ - {shape: bicycle, offset: [0, 2]}
+ - {shape: wrench, offset: [1, -5]}
+ - tags: {amenity: bicycle_rental}
+ shapes: [{shape: bicycle, offset: [0, 2]}, {shape: key, offset: [1, -4]}]
+ - tags: {amenity: fuel}
+ shapes: [fuel_station]
+ - tags: {amenity: parking}
+ shapes: [p]
+ - tags: {amenity: parking, parking: multi-storey}
+ shapes: [{shape: car, offset: [0, 4]}, {shape: car, offset: [0, -3]}]
+ - tags: {highway: turning_circle}
+ shapes: [circle_empty]
+ - tags: {highway: turning_loop}
+ shapes: [turning_loop]
+ - tags: {highway: crossing}
+ shapes: [crossing]
+ - tags: {crossing: zebra}
+ shapes: [crossing]
+ - tags: {highway: crossing, crossing: zebra}
+ shapes: [crossing]
+ - tags: {crossing: marked}
+ shapes: [crossing]
+ - tags: {highway: crossing, crossing: marked}
+ shapes: [crossing]
+ - tags: {highway: crossing, crossing_ref: zebra}
+ shapes: [crossing]
+ - tags: {highway: crossing, crossing: uncontrolled}
+ add_shapes: [no_traffic_signals]
+ - tags: {highway: crossing, crossing: traffic_signals}
+ add_shapes: [traffic_signals]
+ - tags: {highway: traffic_signals}
+ shapes: [traffic_signals]
+ - tags: {crossing_ref: toucan}
+ shapes: [toucan_crossing]
+
+ - tags: {traffic_calming: bump}
+ shapes: [bump]
+ - tags: {traffic_calming: mini_bumps}
+ shapes: [mini_bumps]
+ - tags: {traffic_calming: hump}
+ shapes: [hump]
+ - tags: {traffic_calming: table}
+ shapes: [traffic_table]
+ - tags: {traffic_calming: cushion}
+ shapes: [traffic_cushion]
+ - tags: {traffic_calming: rumble_strip}
+ shapes: [rumble_strip]
+ - tags: {traffic_calming: dip}
+ shapes: [dip]
+ - tags: {traffic_calming: double_dip}
+ shapes: [double_dip]
+
+ - group: "Important middle objects"
+ start_zoom_level: 16.0
+ tags:
+ - tags: {tourism: attraction, attraction: amusement_ride}
+ shapes: [amusement_ride]
+ - tags: {amenity: toilets}
+ shapes: [woman_and_man]
+ - tags: {amenity: shelter}
+ shapes: [shelter]
+ - tags: {man_made: obelisk}
+ shapes: [obelisk]
+ - tags: {historic: monument}
+ shapes: [monument]
+
+ - group: "Normal middle objects"
+ start_zoom_level: 17.0
+ tags:
+ - tags: {shop: kiosk}
+ shapes: [kiosk]
+ - tags: {building: "yes", shop: kiosk}
+ shapes: [kiosk]
+ - tags: {amenity: shop, shop: kiosk}
+ shapes: [kiosk]
+ - tags: {amenity: stage}
+ shapes: [curtains]
+ - tags: {amenity: hunting_stand}
+ shapes: [hunting_stand]
+ - tags: {natural: cave_entrance}
+ shapes: [cave]
+ - tags: {amenity: bureau_de_change}
+ shapes:
+ - {shape: exchange}
+ - {shape: dollar, offset: [-4, 3]}
+ - {shape: pound, offset: [5, -2]}
+ - tags: {sport: skateboard}
+ shapes: [skateboard]
+ - tags: {pipeline: substation}
+ shapes: [pipeline]
+
+ - group: "Towers, poles, masts"
+ start_zoom_level: 15.0
+ tags:
+ - tags: {building: ventilation_shaft}
+ shapes: [ventilation]
+ - tags: {power: generator}
+ shapes: [power_generator]
+ - tags: {amenity: public_bookcase}
+ shapes: [books]
+ - tags: {power: transformer}
+ shapes: [transformer]
+ - tags: {power: generator, generator:source: solar}
+ shapes: [solar_panel]
+ - tags: {power: heliostat}
+ shapes: [solar_panel]
+ - tags: {power: generator, generator:source: wind}
+ shapes: [wind_turbine]
+ - tags: {power: tower}
+ shapes: [power_tower_2_level]
+ - tags: {power: tower, design: one-level}
+ shapes: [power_tower_1_level]
+ - tags: {power: tower, design: two-level}
+ shapes: [power_tower_2_level]
+ - tags: {power: tower, design: three-level}
+ shapes: [power_tower_3_level]
+ - tags: {power: tower, design: four-level}
+ shapes: [power_tower_4_level]
+ - tags: {power: tower, design: donau}
+ shapes: [power_tower_donau]
+ - tags: {power: tower, design: donau_inverse}
+ shapes: [power_tower_donau_inverse]
+ - tags: {power: tower, design: barrel}
+ shapes: [power_tower_barrel]
+ - tags: {power: tower, design: asymmetric}
+ shapes: [power_tower_asymmetric]
+ - tags: {power: tower, design: triangle}
+ shapes: [power_tower_triangle]
+ - tags: {power: tower, design: flag}
+ shapes: [power_tower_flag]
+ - tags: {power: tower, design: delta}
+ shapes: [power_tower_delta]
+ - tags: {power: tower, design: delta_two-level}
+ shapes: [power_tower_delta_2_level]
+ - tags: {power: tower, design: delta_three-level}
+ shapes: [power_tower_delta_3_level]
+ - tags: {power: tower, design: y-frame}
+ shapes: [power_tower_y_frame]
+ - tags: {power: tower, design: x-frame}
+ shapes: [power_tower_x_frame]
+ - tags: {power: tower, design: h-frame}
+ shapes: [power_tower_h_frame]
+ - tags: {power: tower, design: h-frame_two-level}
+ shapes: [power_tower_h_frame_2_level]
+ - tags: {power: tower, design: guyed_h-frame}
+ shapes: [power_tower_guyed_h_frame]
+ - tags: {power: portal}
+ shapes: [power_tower_portal]
+ - tags: {power: tower, design: portal}
+ shapes: [power_tower_portal]
+ - tags: {power: tower, design: portal_two-level}
+ shapes: [power_tower_portal_2_level]
+ - tags: {power: portal, design: portal_two-level}
+ shapes: [power_tower_portal_2_level]
+ - tags: {power: tower, design: portal_three-level}
+ shapes: [power_tower_portal_3_level]
+ - tags: {power: portal, design: portal_three-level}
+ shapes: [power_tower_portal_3_level]
+
+ - tags: {power: pole}
+ shapes: [power_pole_2_level]
+ - tags: {power: pole, design: one-level}
+ shapes: [power_pole_1_level]
+ - tags: {power: pole, design: two-level}
+ shapes: [power_pole_2_level]
+ - tags: {power: pole, design: three-level}
+ shapes: [power_pole_3_level]
+ - tags: {power: pole, design: four-level}
+ shapes: [power_pole_4_level]
+ - tags: {power: pole, design: asymmetric}
+ shapes: [power_pole_asymmetric]
+ - tags: {power: pole, design: triangle}
+ shapes: [power_pole_triangle]
+ - tags: {power: pole, design: flag}
+ shapes: [power_pole_flag]
+ - tags: {power: pole, design: armless_asymmetric}
+ shapes: [power_pole_asymmetric_armless]
+ - tags: {power: pole, design: armless_triangle}
+ shapes: [power_pole_triangle_armless]
+ - tags: {power: pole, design: delta}
+ shapes: [power_pole_delta]
+ - tags: {power: pole, design: delta_two-level}
+ shapes: [power_pole_delta] # power_pole_delta_2_level
+ - tags: {power: pole, design: delta_three-level}
+ shapes: [power_pole_delta] # power_pole_delta_3_level
+
+ - tags: {man_made: chimney}
+ shapes: [chimney]
+ - tags: {man_made: tower, tower:type: cooling}
+ shapes: [tower_cooling]
+ - tags: {man_made: tower, tower:type: defensive}
+ shapes: [tower_defensive]
+ - tags: {man_made: tower, tower:type: pagoda}
+ shapes: [pagoda]
+ - tags: {man_made: tower, tower:type: observation}
+ shapes: [tower_observation]
+ - tags: {man_made: tower, tower:type: watchtower}
+ shapes: [tower_observation]
+ - tags: {man_made: tower, tower:type: minaret}
+ shapes: [minaret]
+ - tags: {man_made: mast}
+ shapes: [tube]
+ - tags: {man_made: stupa}
+ shapes: [stupa]
+
+ - tags: {man_made: mast, tower:construction: guyed_tube}
+ shapes: [tube_guyed]
+ - tags: {man_made: mast, tower:construction: freestanding}
+ shapes: [tube]
+ - tags: {man_made: mast, tower:construction: lattice}
+ shapes: [lattice]
+ - tags: {man_made: mast, tower:construction: guyed_lattice}
+ shapes: [lattice_guyed]
+ - tags: {man_made: mast, tower:type: lighting}
+ shapes:
+ - tube
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags: {man_made: mast, tower:type: communication}
+ shapes:
+ - tube
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags: {man_made: mast, tower:type: siren}
+ shapes:
+ - tube
+ - {shape: siren_left, offset: [-3, -3]}
+ - {shape: siren_right, offset: [3, -3]}
+ - tags: {man_made: mast, tower:type: monitoring}
+ shapes:
+ - tube
+ - {shape: dish_antenna_left, offset: [-3, -3]}
+ - {shape: dish_antenna_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: lighting
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: communication
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: siren
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: siren_left, offset: [-3, -3]}
+ - {shape: siren_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: monitoring
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: dish_antenna_left, offset: [-3, -3]}
+ - {shape: dish_antenna_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: lighting
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: communication
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: siren
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: siren_left, offset: [-3, -3]}
+ - {shape: siren_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: monitoring
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: dish_antenna_left, offset: [-3, -3]}
+ - {shape: dish_antenna_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: lighting
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: light_left, offset: [-4, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: communication
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: wave_left, offset: [-4, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: siren
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: siren_left, offset: [-4, -3]}
+ - {shape: siren_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: monitoring
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: dish_antenna_left, offset: [-4, -3]}
+ - {shape: dish_antenna_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: lighting
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: light_left, offset: [-4, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: communication
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: wave_left, offset: [-4, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: siren
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: siren_left, offset: [-4, -3]}
+ - {shape: siren_right, offset: [3, -3]}
+ - tags:
+ man_made: mast
+ tower:type: monitoring
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: dish_antenna_left, offset: [-4, -3]}
+ - {shape: dish_antenna_right, offset: [3, -3]}
+
+ - tags: {man_made: tower, tower:construction: guyed_tube}
+ shapes: [tube_guyed]
+ - tags: {man_made: tower, tower:construction: freestanding}
+ shapes: [tube]
+ - tags: {man_made: tower, tower:construction: lattice}
+ shapes: [lattice]
+ - tags: {man_made: tower, tower:construction: lattice_guyed}
+ shapes: [lattice_guyed]
+ - tags: {man_made: tower, tower:type: lighting}
+ shapes:
+ - tube
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags: {man_made: tower, tower:type: communication}
+ shapes:
+ - tube
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: lighting
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: communication
+ tower:construction: guyed_tube
+ shapes:
+ - tube_guyed
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: lighting
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: light_left, offset: [-3, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: communication
+ tower:construction: freestanding
+ shapes:
+ - tube
+ - {shape: wave_left, offset: [-3, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: lighting
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: light_left, offset: [-4, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: communication
+ tower:construction: guyed_lattice
+ shapes:
+ - lattice_guyed
+ - {shape: wave_left, offset: [-4, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: lighting
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: light_left, offset: [-4, -3]}
+ - {shape: light_right, offset: [3, -3]}
+ - tags:
+ man_made: tower
+ tower:type: communication
+ tower:construction: lattice
+ shapes:
+ - lattice
+ - {shape: wave_left, offset: [-4, -3]}
+ - {shape: wave_right, offset: [3, -3]}
+
+ - tags: {man_made: flagpole, country: US}
+ shapes: [flag_usa]
+ - tags: {man_made: flagpole, country: Tanzania}
+ shapes: [flag_bend_sinister]
+ - tags: {man_made: flagpole, country: PH}
+ shapes: [flag_triangle_flanche]
+ - tags: {man_made: flagpole, country: FR}
+ shapes: [flag_vertical_triband]
+
+ # Diving towers.
+
+ - tags: {man_made: tower, tower:type: diving}
+ shapes: [diving_1_platforms]
+ - tags: {man_made: tower, tower:type: diving, tower:platforms: "1"}
+ shapes: [diving_1_platforms]
+ - tags: {man_made: tower, tower:type: diving, tower:platforms: "2"}
+ shapes: [diving_2_platforms]
+ - tags: {man_made: tower, tower:type: diving, tower:platforms: "3"}
+ shapes: [diving_3_platforms]
+ - tags: {man_made: tower, tower:type: diving, tower:platforms: "4"}
+ shapes: [diving_4_platforms]
+
+ - tags: {communication:mobile_phone: "yes"}
+ add_shapes: [phone]
+
+ - group: "Moon small objects"
+ start_zoom_level: 0.0
+ tags:
+ - tags: {man_made: rover}
+ shapes: [lunokhod]
+ - tags: {man_made: probe}
+ shapes: [probe]
+ - tags: {man_made: orbiter}
+ shapes: [orbiter]
+ - tags: {man_made: descent_stage}
+ shapes: [descent_stage]
+ - tags: {man_made: third_stage}
+ shapes: [third_stage]
+ - tags: {man_made: lander}
+ shapes: [lander]
+
+ - tags: {man_made: rover, condition: landed}
+ shapes: [lunokhod]
+ set_opacity: 0.8
+ - tags: {man_made: probe, condition: landed}
+ shapes: [probe]
+ set_opacity: 0.8
+ - tags: {man_made: orbiter, condition: landed}
+ shapes: [orbiter]
+ set_opacity: 0.8
+ - tags: {man_made: descent_stage, condition: landed}
+ shapes: [descent_stage]
+ set_opacity: 0.8
+ - tags: {man_made: third_stage, condition: landed}
+ shapes: [third_stage]
+ set_opacity: 0.8
+ - tags: {man_made: lander, condition: landed}
+ shapes: [lander]
+ set_opacity: 0.8
+
+ - tags: {man_made: rover, condition: crashed}
+ shapes: [lunokhod]
+ set_opacity: 0.5
+ - tags: {man_made: probe, condition: crashed}
+ shapes: [probe]
+ set_opacity: 0.5
+ - tags: {man_made: orbiter, condition: crashed}
+ shapes: [orbiter]
+ set_opacity: 0.5
+ - tags: {man_made: descent_stage, condition: crashed}
+ shapes: [descent_stage]
+ set_opacity: 0.5
+ - tags: {man_made: third_stage, condition: crashed}
+ shapes: [third_stage]
+ set_opacity: 0.5
+ - tags: {man_made: lander, condition: crashed}
+ shapes: [lander]
+ set_opacity: 0.5
+
+ - group: "Golf objects"
+ start_zoom_level: 16.0
+ tags:
+ - tags: {golf: tee}
+ shapes: [golf_tee]
+ - tags: {golf: pin}
+ shapes: [golf_pin]
+
+ - group: "Important small objects"
+ start_zoom_level: 17.0
+ tags:
+ - tags: {natural: spring}
+ shapes: [{shape: spring, color: water_border_color}]
+ - tags: {highway: elevator}
+ shapes: [elevator]
+ - tags: {historic: cannon}
+ shapes: [cannon]
+ - tags: {historic: olympic_flame}
+ shapes: [torch]
+ - tags: {historic: memorial}
+ shapes: [memorial]
+ - tags: {historic: memorial, memorial: plaque}
+ shapes: [plaque]
+ - tags: {historic: memorial, memorial: statue}
+ shapes: [statue]
+ - tags: {barrier: artwork, artwork_type: statue}
+ shapes: [statue]
+ - tags: {historic: stone}
+ shapes: [stone_with_inscription]
+ - tags: {historic: wayside_cross}
+ shapes: [cross_and_horizontal_bar]
+ - tags: {historic: wayside_shrine}
+ shapes: [wayside_shrine]
+ - tags: {historic: memorial, memorial: stone}
+ shapes: [stone_with_inscription]
+ - tags: {historic: tomb}
+ shapes: [tomb]
+ - tags: {tomb: "*"}
+ exception: {tomb: mausoleum} # TODO: add exception "tomb: pyramid"
+ shapes: [tomb]
+ - tags: {barrier: toll_booth}
+ shapes: [toll_booth]
+ - tags: {barrier: lift_gate}
+ shapes: [lift_gate]
+ - tags: {barrier: turnstile}
+ shapes: [turnstile]
+ - tags: {barrier: log}
+ shapes: [wood]
+ - tags: {barrier: chain}
+ shapes: [chain_barrier]
+ - tags: {railway: crossing}
+ shapes: [x]
+ - tags: {railway: railway_crossing}
+ shapes: [x]
+ - tags: {railway: level_crossing}
+ shapes: [x]
+ - tags: {railway: signal}
+ shapes: [signal]
+ - tags: {amenity: atm}
+ shapes: [atm]
+ - tags: {amenity: bicycle_parking}
+ shapes:
+ - {shape: bicycle, offset: [0, 2]}
+ - {shape: p_small, offset: [-6, -3]}
+ - tags: {amenity: bicycle_parking, bicycle_parking: stands}
+ shapes:
+ - {shape: p_small, offset: [-5, -3]}
+ - bicycle_parking_stand
+ - tags: {amenity: bicycle_parking, bicycle_parking: wall_loops}
+ shapes:
+ - {shape: p_small, offset: [-5, -3]}
+ - bicycle_parking_wall_loops
+ - tags: {amenity: bicycle_parking, bicycle_parking: rack}
+ shapes:
+ - {shape: p_small, offset: [-5, -3]}
+ - bicycle_parking_rack
+ - tags: {amenity: telephone}
+ shapes: [telephone]
+ - tags: {information: "*"}
+ shapes: [i]
+ replace_shapes: no
+ # tags: {tourism: "*"}
+ # shapes: [historic]
+ # replace_shapes: no
+ - tags: {tourism: information}
+ shapes: [i]
+ - tags: {information: guidepost}
+ shapes: [guidepost]
+ - tags: {tourism: viewpoint}
+ shapes: [binoculars]
+ - tags: {information: board}
+ shapes: [i_in_square]
+ - tags: {buoy: "*"}
+ shapes: [buoy]
+ - tags: {"seamark:type": "*"}
+ shapes: [buoy]
+ - tags: {"waterway:sign": "*"}
+ shapes: [buoy]
+ - tags: {amenity: drinking_water}
+ shapes: [drinking_water]
+ - tags: {tourism: artwork}
+ shapes: [picture]
+ - tags: {tourism: artwork, artwork_type: statue}
+ shapes: [statue]
+ - tags: {tourism: artwork, artwork_type: stone}
+ shapes: [stone_with_inscription]
+ - tags: {exhibit: artwork, artwork_type: statue}
+ shapes: [statue_exhibit]
+ - tags: {tourism: artwork, artwork_type: sculpture}
+ shapes: [statue]
+ - tags: {exhibit: artwork, artwork_type: sculpture}
+ shapes: [statue_exhibit]
+ - tags: {exhibit: artwork, artwork_type: painting}
+ shapes: [picture]
+ - tags: {exhibit: artwork, artwork_type: stained_glass}
+ shapes: [stained_glass]
+ - tags: {tourism: attraction}
+ shapes: [photo_camera]
+ - tags: {xmas:feature: tree}
+ shapes: {christmas_tree}
+
+ - group: "Normal small objects"
+ start_zoom_level: 18.0
+ tags:
+ - tags: {railway: switch}
+ shapes: [y]
+ - tags: {amenity: binoculars}
+ shapes: [binoculars_on_pole]
+ - tags: {amenity: parking_space}
+ shapes: [p_small]
+ - tags: {amenity: parking, parking: lane}
+ shapes: [p_small]
+ - tags: {amenity: parking, parking: street_side}
+ shapes: [p_small]
+ - tags: {amenity: post_box}
+ shapes: [envelope]
+ - tags: {amenity: recycling}
+ shapes: [recycling_container]
+ - tags: {amenity: recycling, recycling_type: container}
+ shapes: [recycling_container]
+ - tags: {amenity: shower}
+ shapes: [shower]
+ - tags: {amenity: vending_machine}
+ shapes: [vending_machine]
+ - tags: {vending: admission_tickets}
+ shapes: [vending_tickets]
+ - tags: {amenity: vending_machine, vending: admission_tickets}
+ shapes: [vending_tickets]
+ - tags: {vending: candles}
+ shapes: [vending_candles]
+ - tags: {amenity: vending_machine, vending: candles}
+ shapes: [vending_candles]
+ - tags: {vending: chemist}
+ shapes: [vending_chemist]
+ - tags: {amenity: vending_machine, vending: chemist}
+ shapes: [vending_chemist]
+ - tags: {vending: drinks}
+ shapes: [vending_bottle]
+ - tags: {amenity: vending_machine, vending: drinks}
+ shapes: [vending_bottle]
+ - tags: {vending: excrement_bags}
+ shapes: [vending_excrement_bag]
+ - tags: {amenity: vending_machine, vending: excrement_bags}
+ shapes: [vending_excrement_bag]
+ - tags: {vending: fishing_tackle}
+ shapes: [vending_angle]
+ - tags: {amenity: vending_machine, vending: fishing_tackle}
+ shapes: [vending_angle]
+ - tags: {vending: public_transport_tickets}
+ shapes: [vending_tickets]
+ - tags: {amenity: vending_machine, vending: public_transport_tickets}
+ shapes: [vending_tickets]
+ - tags: {vending: parking_tickets}
+ shapes: [vending_p]
+ - tags: {amenity: vending_machine, vending: parking_tickets}
+ shapes: [vending_p]
+ - tags: {vending: water}
+ shapes: [vending_drop]
+ - tags: {amenity: vending_machine, vending: water}
+ shapes: [vending_drop]
+ - tags: {fitness_station: horizontal_bar}
+ shapes: [horizontal_bar]
+ - tags: {fitness_station: rings}
+ shapes: [rings]
+ - tags: {fitness_station: wall_bars}
+ shapes: [wall_bars]
+ - tags: {fitness_station: sit-up}
+ shapes: [sit_up]
+ - tags: {fitness_station: horizontal_ladder}
+ shapes: [horizontal_ladder]
+ - tags: {fitness_station: push-up}
+ shapes: [low_horizontal_bars]
+ - tags: {playground: hopscotch}
+ shapes: [hopscotch]
+ - tags: {playground: slide}
+ shapes: [slide]
+ - tags: {attraction: water_slide}
+ shapes: [slide_and_water]
+ - tags: {playground: roundabout}
+ shapes: [roundabout]
+ - tags: {playground: sandpit}
+ shapes: [sandpit]
+ - tags: {playground: seesaw}
+ shapes: [seesaw]
+ - tags: {playground: horizontal_bar}
+ shapes: [horizontal_bar]
+ - tags: {leisure: picnic_table}
+ shapes: [table]
+ - tags: {highway: traffic_mirror}
+ shapes: [side_mirror]
+ - tags: {amenity: dressing_room}
+ shapes: [hanger]
+
+ - group: "Entrances"
+ start_zoom_level: 18.0
+ tags:
+ - tags: {amenity: parking_entrance}
+ shapes:
+ - {shape: p, offset: [-1, 0]}
+ - {shape: arrow_right, offset: [4, 5]}
+ - tags: {amenity: parking_entrance, parking: underground}
+ shapes: [{shape: p, offset: [-1, 0]}, {shape: arrow_down, offset: [4, 5]}]
+ - tags: {amenity: parking_entrance, parking: multi-storey}
+ shapes: [{shape: p, offset: [-1, 0]}, {shape: arrow_up, offset: [4, 5]}]
+ - tags: {entrance: gate}
+ shapes: [gate]
+ - tags: {barrier: gate}
+ shapes: [gate]
+ - tags: {entrance: garage}
+ shapes: [garage_door]
+ - tags: {entrance: main}
+ shapes: [main_entrance]
+ - tags: {barrier: entrance}
+ shapes: [entrance]
+ - tags: {barrier: door}
+ shapes: [entrance]
+ - tags: {entrance: "yes"}
+ shapes: [entrance]
+ - tags: {building: entrance}
+ shapes: [{shape: entrance, color: "#FF0000"}]
+ - tags: {entrance: shop}
+ shapes: [entrance]
+ - tags: {entrance: exit}
+ shapes: [exit]
+ - tags: {entrance: service}
+ shapes: [door_with_keyhole]
+ - tags: {entrance: staircase}
+ shapes: [staircase]
+ - tags: {door: "no"}
+ shapes: [no_door]
+
+ - group: "Not important small objects"
+ start_zoom_level: 18.0
+ tags:
+ - tags: {amenity: bench}
+ shapes: [bench]
+ - tags: {amenity: bench, backrest: "yes"}
+ shapes: [bench_backrest]
+ - tags: {amenity: bench, backrest: "no"}
+ shapes: [bench_no_backrest]
+ - tags: {amenity: bench, tourism: artwork, artwork_type: sculpture}
+ shapes: [bench_with_statue]
+ - tags: {amenity: bench, tourism: artwork, artwork_type: statue}
+ shapes: [bench_with_statue]
+ - tags: {historic: memorial, memorial: bench, amenity: bench}
+ shapes: [bench_with_inscription]
+ - tags: {historic: memorial, memorial: bench}
+ shapes: [bench_with_inscription]
+ - tags: {memorial: bench}
+ shapes: [bench_with_inscription]
+ - tags: {amenity: clock}
+ shapes: [clock]
+ - tags: {amenity: fountain}
+ shapes: [{shape: fountain, color: water_border_color}]
+ - tags: {fountain: bubbler}
+ shapes: [fountain_bubbler]
+ - tags: {fountain: roman_wolf}
+ shapes: [fountain_roman_wolf]
+ - tags: {fountain: toret}
+ shapes: [fountain_toret]
+ - tags: {amenity: waste_basket}
+ shapes: [waste_basket]
+ - tags: {amenity: waste_disposal}
+ shapes: [waste_disposal]
+ - tags: {highway: street_lamp}
+ shapes: [street_lamp]
+ - tags: {amenity: bbq}
+ shapes: [bbq]
+ - tags: {leisure: firepit}
+ shapes: [fire_pit]
+ - tags: {man_made: cross}
+ shapes: [latin_cross]
+ - tags: {man_made: flagpole}
+ shapes: [flagpole]
+ - tags: {man_made: manhole}
+ shapes: [circle_9]
+ - tags: {manhole: drain}
+ shapes: [manhole_drain]
+ - tags: {man_made: pole}
+ shapes: [pole]
+ - tags: {man_made: pole, highway: street_lamp}
+ shapes: [pole_lamp]
+ - tags: {man_made: street_cabinet}
+ shapes: [street_cabinet]
+ - tags: {man_made: surveillance}
+ shapes: [cctv]
+ - tags: {man_made: surveillance, camera:type: dome, camera:mount: ceiling}
+ shapes: [cctv_dome_ceiling]
+ - tags: {man_made: surveillance, camera:type: dome, camera:mount: wall}
+ shapes: [cctv_dome_wall]
+ - tags: {man_made: ventilation_shaft}
+ shapes: [ventilation]
+ - tags: {railway: ventilation_shaft}
+ shapes: [ventilation]
+ - tags: {advertising: billboard}
+ shapes: [billboard]
+ - tags: {advertising: column}
+ shapes: [advertising_column]
+ - tags: {natural: human}
+ shapes: [human]
+ - tags: {natural: rock}
+ shapes: [stone]
+ - tags: {natural: stone}
+ shapes: [stone]
+ - tags: {sloped_curb: "yes"}
+ shapes: [lowered_kerb]
+ - tags: {kerb: lowered}
+ shapes: [lowered_kerb]
+
+ - tags: {railway: buffer_stop}
+ shapes: [buffer_stop]
+ - tags: {traffic_sign: "*"}
+ shapes: [guidepost]
+ - tags: {traffic_sign: city_limit}
+ shapes: [city_limit_sign]
+ - tags: {traffic_sign: maxspeed, maxspeed: "^(\\d)(\\d)$"}
+ shapes: [
+ circle_11,
+ {shape: digit_#maxspeed0, offset: [-2, 0], color: "#FFFFFF"},
+ {shape: digit_#maxspeed1, offset: [2, 0], color: "#FFFFFF"},
+ ]
+ - tags: {traffic_sign: maxspeed, maxspeed: "^(\\d)(\\d) mph$"}
+ shapes: [
+ speed_limit_mph,
+ {shape: digit_#maxspeed0, offset: [-2, 2]},
+ {shape: digit_#maxspeed1, offset: [2, 2]},
+ ]
+ - tags: {highway: milestone}
+ shapes: [milestone]
+ - tags: {traffic_sign: stop}
+ shapes: [stop]
+ - tags: {highway: give_way}
+ shapes: [triangle_down_hollow]
+ - tags: {noexit: "yes"}
+ shapes: [t]
+ - tags: {barrier: block}
+ shapes: [block]
+ - tags: {barrier: rock}
+ shapes: [stone]
+ - tags: {barrier: bollard}
+ shapes: [bollard]
+ - tags: {barrier: kerb}
+ shapes: [kerb]
+ - tags: {tank_trap: czech_hedgehog}
+ shapes: [czech_hedgehog]
+ - tags: {tank_trap: dragons_teeth}
+ shapes: [dragons_teeth]
+ - tags: {tank_trap: toblerone}
+ shapes: [dragons_teeth]
+
+ - group: "Trees"
+ start_zoom_level: 18.0
+ tags:
+ - tags: {natural: tree}
+ shapes: [{shape: tree, color: tree_color, outline: no}]
+ - tags: {leaf_type: broadleaved}
+ shapes: [{shape: tree_with_leaf, color: tree_color}]
+ - tags: {leaf_type: needleleaved}
+ shapes: [{shape: needleleaved_tree, color: tree_color}]
+ - tags: {leaf_type: palm}
+ shapes: [{shape: palm, color: tree_color}]
+ - tags: {natural: tree, leaf_type: broadleaved}
+ shapes: [{shape: tree_with_leaf, color: tree_color}]
+ - tags: {natural: tree, leaf_type: needleleaved}
+ shapes: [{shape: needleleaved_tree, color: tree_color}]
+ - tags: {natural: tree, leaf_type: palm}
+ shapes: [{shape: palm, color: tree_color}]
+ - tags: {natural: tree, type: conifer}
+ shapes: [{shape: needleleaved_tree, color: tree_color}]
+ - tags: {leaf_cycle: deciduous}
+ set_main_color: decidious_color
+ - tags: {leaf_cycle: evergreen}
+ set_main_color: evergreen_color
+ - tags: {natural: tree, leaf_cycle: deciduous}
+ set_main_color: decidious_color
+ - tags: {natural: tree, leaf_cycle: evergreen}
+ set_main_color: evergreen_color
+ - tags: {natural: bush}
+ shapes: [{shape: bush, color: tree_color}]
+
+ - tags: {natural: tree, genus: Betula}
+ shapes: [{shape: betula, color: tree_color}]
+ - tags: {natural: tree, "genus:en": Birch}
+ shapes: [{shape: betula, color: tree_color}]
+ - tags: {natural: tree, "genus:ru": Берёза}
+ shapes: [{shape: betula, color: tree_color}]
+
+ - tags: {natural: tree, genus: Acer}
+ shapes: [{shape: tree, color: tree_color}]
+ add_shapes: [{shape: leaf_maple, color: tree_color}]
+ - tags: {natural: tree, "genus:en": Maple}
+ shapes: [{shape: tree, color: tree_color}]
+ add_shapes: [{shape: leaf_maple, color: tree_color}]
+
+ - tags: {natural: tree, genus: Malus}
+ shapes: [{shape: tree, color: tree_color}]
+ add_shapes: [{shape: apple, color: tree_color}]
+
+ - tags: {natural: tree, genus: Pyrus}
+ shapes: [{shape: tree, color: tree_color}]
+ add_shapes: [{shape: pear, color: tree_color}]
+
+ - group: "Indoor"
+ start_zoom_level: 18.0
+ tags:
+ - tags: {door: "yes"}
+ shapes: [entrance]
+ - tags: {indoor: pillar}
+ shapes: [pillar]
+
+ - group: "Mass objects"
+ start_zoom_level: 19.0
+ tags:
+ - tags: {tourism: camp_pitch}
+ shapes: [camp]
+
+ - group: "Add and over"
+ tags:
+ - tags: {support: pole}
+ over_icon: [support_pole]
+ under_icon: [clock, i_in_square]
+ - tags: {support: wall_mounted}
+ over_icon: [support_wall]
+ under_icon: [clock, i_in_square]
+ - tags: {support: column}
+ over_icon: [support_column]
+ under_icon: [clock, i_in_square]
+ - tags: {amenity: "*", karaoke: "yes"}
+ add_shapes: [microphone]
+ - tags: {building: "*", "roof:shape": onion}
+ add_shapes: [onion_roof_shape]
+ - tags: {natural: tree, denotation: urban}
+ over_icon: [urban_tree_pot]
+ under_icon: [tree, tree_with_leaf, needleleaved_tree, betula, palm]
+ - tags: {natural: tree, denotation: avenue}
+ over_icon: [bottom_right_horizontal_line]
+ under_icon: [tree, tree_with_leaf, needleleaved_tree, betula, palm]
+
+ - tags: {wheelchair: "yes"}
+ add_shapes: [wheelchair]
+ - tags: {wheelchair: "no"}
+ add_shapes: [no_wheelchair]
+ - tags: {foot: "yes"}
+ add_shapes: [foot]
+ - tags: {foot: "no"}
+ add_shapes: [no_foot]
+ - tags: {bicycle: "yes"}
+ add_shapes: [bicycle]
+ - tags: {bicycle: "no"}
+ shapes:
+ - {shape: bicycle, offset: [0, 2]}
+ - {shape: x_4, offset: [-5, -4]}
+ - tags: {internet_access: wlan, "internet_access:fee": "no"}
+ add_shapes:
+ - {shape: wlan, offset: [0, -3]}
+ - {shape: free, offset: [0, 5]}
+ - tags: {internet_access: wlan}
+ exception: {"internet_access:fee": "*"}
+ add_shapes: [wlan]
+ - tags: {material: wood}
+ add_shapes: [{shape: wood, color: trunk_color}]
+ - tags: {access: private}
+ add_shapes: [lock_with_keyhole]
+ - tags: {access: "no"}
+ add_shapes: [lock]
+ - tags: {direction: clockwise}
+ add_shapes: [clockwise]
+ - tags: {direction: anticlockwise}
+ add_shapes: [counterclockwise]
+ - tags: {atm: "yes"}
+ add_shapes: [atm]
+ - tags: {tactile_paving: "yes"}
+ add_shapes: [tactile_paving]
+ - tags: {tactile_paving: "no"}
+ add_shapes:
+ - {shape: tactile_paving, offset: [0, 2]}
+ - {shape: x_5, offset: [0, -3]}
+ - tags: {"payment:credit_cards": "yes"}
+ add_shapes: [credit_card]
+
+ - tags: {bus: "yes"}
+ add_shapes: [bus]
+ - tags: {motorcar: "yes"}
+ add_shapes: [car]
+ - tags: {car: "yes"}
+ add_shapes: [car]
+ - tags: {monorail: "yes"}
+ add_shapes: [monorail]
+ - tags: {trolleybus: "yes"}
+ add_shapes: [trolleybus]
+
+ - tags: {recycling:glass_bottles: "yes"}
+ add_shapes: [bottle]
+ - tags: {recycling:paper: "yes"}
+ add_shapes: [gazette]
+ - tags: {recycling:glass: "yes"}
+ add_shapes: [bottle_and_wine_glass]
+ - tags: {recycling:clothes: "yes"}
+ add_shapes: [t_shirt]
+ - tags: {recycling:shoes: "yes"}
+ add_shapes: [shoe]
+ - tags: {recycling:green_waste: "yes"}
+ add_shapes: [apple]
+ - tags: {recycling:paper_packaging: "yes"}
+ add_shapes: [gazette]
+ - tags: {recycling:newspaper: "yes"}
+ add_shapes: [gazette]
+ - tags: {recycling:magazines: "yes"}
+ add_shapes: [gazette]
+ - tags: {recycling:books: "yes"}
+ add_shapes: [book]
+ - tags: {recycling:wood: "yes"}
+ add_shapes: [{shape: wood, color: trunk_color}]
+ - tags: {recycling:glass_bottles:colour: "yes"}
+ add_shapes: [{shape: bottle, color: green}]
+ - tags: {recycling:cartons: "yes"}
+ add_shapes: [aseptic_carton]
+ - tags: {recycling:beverage_cartons: "yes"}
+ add_shapes: [aseptic_carton]
+ - tags: {recycling:organic: "yes"}
+ add_shapes: [apple]
+ - tags: {recycling:tetrapak: "yes"}
+ add_shapes: [aseptic_carton]
+ - tags: {recycling:tyres: "yes"}
+ add_shapes: [tyre]
+ - tags: {recycling:toys: "yes"}
+ add_shapes: [toy_horse]
+ - tags: {recycling:verre: "yes"}
+ add_shapes: [bottle_and_wine_glass]
+ - tags: {recycling:bags: "yes"}
+ add_shapes: [bag]
+
+ - tags: {crossing:island: "yes"}
+ add_shapes: [rectangle_vertical_rounded]
+ - tags: {crossing:island: "no"}
+ add_shapes: [rectangle_vertical_rounded_crossed]
+
+ - tags: {parking: "yes"}
+ add_shapes: [p]
+ - tags: {drinking_water: "yes"}
+ add_shapes: [drinking_water]
+ - tags: {toilets: "yes"}
+ add_shapes: [woman_and_man]
+ - tags: {washing_machine: "yes"}
+ add_shapes: [washing_machine]
+ - tags: {shower: "yes"}
+ add_shapes: [shower]
+ - tags: {dog: "yes"}
+ add_shapes: [dog]
+
+ # For tourism=camp_pitch
+
+ - tags: {caravans: "yes"}
+ add_shapes: [caravan]
+ - tags: {tents: "yes"}
+ add_shapes: [camp]
+ - tags: {fireplace: "yes"}
+ add_shapes: [fireplace]
+ - tags: {openfire: "yes"}
+ add_shapes: [fire_pit]
+ - tags: {bbq: "yes"}
+ add_shapes: [bbq]
+
+roads:
+ - tags: {highway: motorway}
+ default_width: 7.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.8
+ - tags: {highway: trunk}
+ default_width: 7.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.0
+ - tags: {highway: trunk_link}
+ default_width: 7.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.0
+ - tags: {highway: primary}
+ default_width: 7.0
+ border_color: primary_border_color
+ color: primary_color
+ priority: 41.7
+ - tags: {highway: motorway_link}
+ default_width: 7.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.8
+ - tags: {highway: secondary}
+ default_width: 7.0
+ border_color: secondary_border_color
+ priority: 41.6
+ color: secondary_color
+ - tags: {highway: secondary_link}
+ default_width: 7.0
+ border_color: secondary_border_color
+ priority: 41.6
+ color: secondary_color
+ - tags: {highway: tertiary}
+ default_width: 7.0
+ border_color: tertiary_border_color
+ priority: 41.5
+ color: tertiary_color
+ - tags: {highway: tertiary_link}
+ default_width: 7.0
+ border_color: tertiary_border_color
+ priority: 41.5
+ color: tertiary_color
+ - tags: {highway: unclassified}
+ default_width: 5.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: residential}
+ default_width: 5.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: living_street}
+ default_width: 4.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: service}
+ exception: {service: parking_aisle}
+ default_width: 3.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: service, service: parking_aisle}
+ default_width: 2.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {leisure: track}
+ color: pitch_color
+ border_color: pitch_border_color
+ default_width: 5.0
+ priority: 21.0
+
+ - tags: {highway: raceway}
+ color: pitch_color
+ border_color: pitch_border_color
+ default_width: 7.0
+ priority: 21.0
+
+ways:
+ - tags: {man_made: bridge}
+ style: {fill: "#AAAAAA"}
+ priority: 22.0
+ - tags: {indoor: area}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 10.0
+ - tags: {indoor: corridor}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 11.0
+ - tags: {highway: corridor}
+ style:
+ stroke: "#00FF00"
+ stroke-width: 5.0
+ priority: 11.0
+ - tags: {indoor: "yes", area: "yes"}
+ style:
+ stroke: indoor_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 12.0
+ - tags: {indoor: room}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ priority: 12.0
+ - tags: {indoor: elevator, area: "yes"}
+ style:
+ stroke: indoor_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 12.0
+ - tags: {indoor: column}
+ style:
+ stroke: indoor_column_color
+ stroke-width: 1.0
+ fill: indoor_column_color
+ priority: 13.0
+
+ - tags: {power: line}
+ style:
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.2
+ priority: 80.0
+ - tags: {power: cable}
+ style:
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.1
+ priority: 80.0
+ - tags: {golf: hole}
+ style:
+ stroke: "#44AA33"
+ stroke-width: 1.0
+ opacity: 0.6
+ priority: 80.0
+ - tags: {man_made: pipeline}
+ style:
+ stroke: "#888888"
+ stroke-width: 1.0
+ stroke-dasharray: "12.0,1.5"
+ priority: 80.0
+ - tags: {man_made: pipeline}
+ style:
+ stroke: "#888888"
+ stroke-width: 3.0
+ stroke-dasharray: "1.0,10.0,1.0,1.5"
+ priority: 80.0
+
+ - tags: {attraction: water_slide}
+ style:
+ stroke: "#FFFFFF"
+ stroke-width: 2.0
+ priority: 81
+ - tags: {attraction: water_slide}
+ style:
+ stroke: "#888888"
+ stroke-width: 4.0
+ priority: 80.0
+
+ - tags: {highway: track}
+ style:
+ stroke-width: 1.5
+ stroke: track_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: footway}
+ exception: {area: "yes", type: "multipolygon"}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: pedestrian}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: cycleway}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: steps}
+ style:
+ stroke-width: 6.0
+ stroke: foot_border_color
+ stroke-linecap: butt
+ - tags: {highway: path}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ priority: 41.0
+
+ - tags: {highway: footway}
+ exception: {area: "yes", type: "multipolygon"}
+ style:
+ stroke-width: 1.5
+ stroke-dasharray: 7.0,3.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: pedestrian}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 1.5
+ stroke-dasharray: 7.0,3.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: footway, area: "yes"}
+ style:
+ stroke: foot_area_border_color
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 55.0
+ - tags: {highway: footway, type: "multipolygon"}
+ style:
+ stroke: foot_area_border_color
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 55.0
+ - tags: {highway: pedestrian, area: "yes"}
+ style:
+ stroke: none
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: -55.0 # FIXME
+ - tags: {highway: cycleway}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 1.0
+ stroke: cycle_color
+ stroke-dasharray: 8.0,2.0
+ stroke-linecap: butt
+ priority: 42.0
+ - tags: {highway: steps, conveying: "*"}
+ style:
+ stroke-width: 5.0
+ stroke-dasharray: 1.5,2.0
+ stroke-linecap: butt
+ stroke: "#888888"
+ priority: 42.0
+ - tags: {highway: steps}
+ exception: {conveying: "*"}
+ style:
+ stroke-width: 5.0
+ stroke-dasharray: 1.5,2.0
+ stroke-linecap: butt
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: path}
+ style:
+ stroke-width: 1.5
+ stroke-dasharray: 5.0,3.0
+ stroke-linecap: butt
+ stroke: foot_color
+ priority: 42.0
+
+ - tags: {aeroway: runway}
+ style:
+ stroke-width: 50.0
+ stroke: runway_color
+ priority: 22.0
+ - tags: {aeroway: taxiway}
+ style:
+ stroke-width: 50.0
+ stroke: taxiway_color
+ priority: 21.0
+ - tags: {aeroway: runway}
+ style:
+ stroke-width: 2.0
+ stroke: "#DDDDDD"
+ stroke-dasharray: 40.0,20.0
+ priority: 23.0
+ - tags: {aeroway: taxiway}
+ style:
+ stroke-width: 1.0
+ stroke: "#CCCCCC"
+ priority: 23.0
+ - tags: {aeroway: parking_position}
+ style:
+ stroke-width: 1.0
+ stroke: "#DDCC00"
+ priority: 23.0
+ - tags: {area:aeroway: taxiway}
+ style:
+ fill: "#CCCCCC"
+ priority: 20.0
+
+ - tags: {natural: wood}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: wetland}
+ style:
+ fill: wetland_color
+ priority: 21.0
+ - tags: {natural: grassland}
+ style:
+ fill: grass_color
+ stroke: grass_border_color
+ priority: 20.0
+ - tags: {natural: scrub}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: sand}
+ style:
+ fill: sand_color
+ priority: 20.0
+ - tags: {natural: beach}
+ style:
+ fill: beach_color
+ priority: 20.0
+ - tags: {natural: heath}
+ style:
+ fill: "#DDDDDD"
+ priority: 20.0
+ - tags: {natural: glacier}
+ style:
+ fill: "#FFFFFF"
+ priority: 20.0
+ - tags: {natural: desert}
+ style:
+ fill: desert_color
+ priority: 20.0
+ - tags: {natural: forest}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: tree_row}
+ priority: 21.0
+ style:
+ stroke: wood_color
+ stroke-width: 5.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ - tags: {natural: water}
+ exception: {intermittent: "yes"}
+ style:
+ fill: water_color
+ # stroke: water_border_color
+ # stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: water, intermittent: "yes"}
+ style:
+ fill: water_color
+ opacity: 0.5
+ # stroke: water_border_color
+ # stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: coastline}
+ style:
+ # fill: water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: ridge}
+ style:
+ stroke-width: 2.0
+ opacity: 0.3
+ stroke: ridge_color
+ priority: 21.0
+ - tags: {natural: bare_rock}
+ style:
+ fill: rock_color
+ - tags: {natural: cliff}
+ style:
+ stroke-width: 1.0
+ stroke: "#BBBBBB"
+ - tags: {natural: cliff}
+ parallel_offset: -2.5
+ style:
+ stroke-width: 5.0
+ stroke: "#BBBBBB"
+ stroke-dasharray: "1,10"
+ - tags: {natural: scree}
+ style:
+ fill: scree_color
+
+ - tags: {landuse: allotments}
+ style:
+ fill: allotments_color
+ priority: 20.0
+ - tags: {landuse: conservation}
+ style:
+ fill: grass_color
+ priority: 20.0
+ - tags: {landuse: construction}
+ style:
+ fill: construction_color
+ - tags: {landuse: farmland}
+ style: {fill: farmland_color, stroke: farmland_border_color}
+ priority: 20.0
+ - tags: {landuse: greenhouse_horticulture}
+ style: {fill: farmland_color, stroke: farmland_border_color}
+ priority: 20.0
+ - tags: {landuse: farmyard}
+ style: {fill: farmland_color, stroke: farmland_border_color} # FIXME
+ priority: 20.0
+ - tags: {landuse: farmland, crop: wheat}
+ style: {fill: wheat_color, stroke: wheat_border_color}
+ priority: 20.0
+ - tags: {landuse: farmland, crop: barley}
+ style: {fill: barley_color, stroke: barley_border_color}
+ priority: 20.0
+ - tags: {landuse: farmland, crop: rye}
+ style: {fill: rye_color, stroke: rye_dark_color}
+ priority: 20.0
+ - tags: {landuse: forest}
+ style:
+ fill: wood_color
+ priority: 20.0
+ - tags: {landuse: garages}
+ style:
+ fill: parking_color
+ priority: 21.0
+ - tags: {landuse: village_green}
+ style:
+ fill: village_green_color
+ priority: 20.0
+ - tags: {landuse: grass}
+ style:
+ fill: grass_color
+ stroke: grass_border_color
+ priority: 20.0
+ - tags: {landuse: orchard}
+ style:
+ fill: orchard_color
+ priority: 21.0
+ - tags: {landuse: meadow}
+ style:
+ fill: meadow_color
+ stroke: meadow_border_color
+ priority: 20.0
+ - tags: {tourism: camp_pitch}
+ style:
+ stroke: "#000000"
+ opacity: 0.1
+ priority: 100.0
+
+ # Hidden land use
+
+ - tags: {landuse: cemetery}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {landuse: commercial}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {landuse: industrial}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {landuse: military}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {landuse: railway}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {landuse: residential}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+ - tags: {power: substation}
+ style:
+ fill: hidden_color
+ opacity: 0.05
+ priority: 1.0
+
+ - tags: {amenity: ferry_terminal}
+ style:
+ fill: ferry_terminal_color
+ priority: 50.0
+ - tags: {amenity: parking}
+ style:
+ fill: parking_color
+ opacity: 0.5
+ - tags: {amenity: parking_space}
+ style:
+ stroke: "#FFFFFF"
+
+ - tags: {aeroway: landingpad}
+ style:
+ fill: "#000000"
+ opacity: 0.1
+ - tags: {aeroway: helipad}
+ style:
+ fill: "#440044"
+ opacity: 0.1
+
+ - tags: {waterway: river}
+ style:
+ stroke: water_color
+ stroke-width: 2.5
+ priority: 22.0
+ - tags: {waterway: canal}
+ style:
+ stroke: water_color
+ stroke-width: 2.0
+ priority: 22.0
+ - tags: {waterway: stream}
+ style:
+ stroke: water_color
+ stroke-width: 1.5
+ priority: 22.0
+ - tags: {waterway: riverbank}
+ style:
+ fill: water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+ priority: 22.0
+ - tags: {waterway: ditch}
+ style:
+ fill: water_color
+ stroke: water_color
+ stroke-width: 2.0
+ priority: 22.0
+
+ - tags: {railway: subway}
+ style:
+ stroke-width: 3.5
+ opacity: 0.7
+ stroke: "#AAAAAA"
+ priority: 41.0
+ - tags: {railway: rail}
+ style:
+ stroke-width: 3.0
+ stroke: "#BBBBBB"
+ priority: 42.0
+ - tags: {railway: light_rail}
+ style:
+ stroke-width: 3.0
+ stroke: "#CCCCCC"
+ priority: 42.0
+ - tags: {railway: monorail}
+ style:
+ stroke-width: 3.0
+ stroke: "#CCCCCC"
+ priority: 42.0
+ - tags: {railway: funicular}
+ style:
+ stroke-width: 3.0
+ stroke: "#CCCCCC"
+ priority: 42.0
+ - tags: {railway: narrow_gauge}
+ style:
+ stroke-width: 3.0
+ stroke: "#DDDDDD"
+ priority: 42.0
+ - tags: {railway: tram}
+ style:
+ stroke-width: 3.0
+ stroke: "#BBBBBB"
+ priority: 42.0
+ - tags: {railway: construction}
+ style:
+ stroke-width: 2.0
+ stroke: "#000000"
+ stroke-dasharray: 6,3
+ opacity: 0.3
+ priority: 42.0
+ - tags: {railway: disused}
+ style:
+ stroke-width: 3.0
+ stroke: "#000000"
+ stroke-dasharray: 6,6
+ opacity: 0.3
+ priority: 42.0
+ - tags: {railway: abandoned}
+ style:
+ stroke-width: 3.0
+ stroke: "#000000"
+ stroke-dasharray: 6,9
+ opacity: 0.3
+ priority: 42.0
+
+ - tags: {railway: rail}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+ - tags: {railway: light_rail}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+ - tags: {railway: monorail}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+ - tags: {railway: funicular}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+ - tags: {railway: narrow_gauge}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+ - tags: {railway: tram}
+ style:
+ stroke-width: 1.0
+ stroke: "#444444"
+ priority: 43.0
+
+ - tags: {railway: platform}
+ style:
+ fill: platform_color
+ stroke-width: 1.0
+ stroke: platform_border_color
+ priority: 41.0
+
+ - tags: {route: ferry}
+ style:
+ stroke-width: 1.0
+ stroke-dasharray: 3.0,3.0
+ stroke-linecap: butt
+ stroke: route_color
+ priority: 42.0
+
+ - tags: {leisure: garden}
+ style:
+ fill: grass_color
+ priority: 21.0
+ - tags: {leisure: park}
+ style:
+ fill: park_color
+ opacity: 0.5
+ - tags: {landuse: recreation_ground}
+ style:
+ fill: grass_color
+ opacity: 0.5
+ - tags: {leisure: recreation_ground}
+ style:
+ fill: grass_color
+ opacity: 0.5
+ - tags: {leisure: stadium}
+ style:
+ fill: grass_color
+ opacity: 0.5
+ - tags: {leisure: golf_course}
+ style:
+ fill: grass_color
+ opacity: 0.5
+ - tags: {leisure: pitch}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: bleachers}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: track, area: "yes"}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: fitness_station}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: playground}
+ style:
+ fill: playground_color
+ stroke: playground_border_color
+ priority: 21.0
+ - tags: {leisure: swimming_pool}
+ style:
+ fill: water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+
+ - tags: {tourism: artwork}
+ style:
+ stroke: "#888888"
+ stroke-width: 1.0
+ priority: 10.0
+ - tags: {barrier: hedge}
+ style:
+ fill: none
+ stroke: wood_color
+ stroke-width: 4.0
+ priority: 40.0
+ - tags: {barrier: city_wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 2.0
+ opacity: 0.5
+ priority: 40.0
+ - tags: {barrier: wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.5
+ opacity: 0.4
+ priority: 40.0
+ - tags: {man_made: embankment}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.3
+ priority: 40.0
+ - tags: {man_made: embankment}
+ parallel_offset: -1.5
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 3.0
+ opacity: 0.3
+ stroke-dasharray: "1,7"
+ priority: 40.0
+ - tags: {barrier: fence}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.25
+ priority: 40.0
+ - tags: {barrier: retaining_wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.25
+ priority: 40.0
+ - tags: {barrier: handrail}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.2
+ priority: 40.0
+ - tags: {barrier: kerb}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0.15
+ priority: 40.0
+
+ - tags: {border: "*"}
+ style:
+ stroke: "#FF0000"
+ stroke-width: 0.5
+ stroke-dasharray: 10.0,20.0
+ - tags: {"area:highway": "*"}
+
+ - tags: {boundary: "*"}
+ # style:
+ # stroke: boundary_color
+ # stroke-width: 0.3
+ # stroke-dasharray: 10.0,5.0
+ priority: 60.0
+
+area_tags:
+ - tags: {aeroway: "*"}
+ - tags: {building: "*"}
+ - tags: {landuse: "*"}
+ - tags: {leisure: "*"}
+ - tags: {natural: "*"}
+ exception: {natural: "tree_row"}
+ - tags: {indoor: "corridor"}
+ - tags: {power: "compensator"}
+ - tags: {power: "substation"}
+
+keys_to_write:
+ - "STIF:zone"
+ - "alt_name"
+ - "artist_name"
+ - "booth"
+ - "branch"
+ - "brand"
+ - "capacity"
+ - "cladr:code"
+ - "collection_times"
+ - "created_by"
+ - "cuisine"
+ - "cyclestreets_id"
+ - "description"
+ - "designation"
+ - "destination"
+ - "ele"
+ - "email"
+ - "end_date"
+ - "facebook"
+ - "fax"
+ - "fhrs:confidence_management"
+ - "fhrs:hygiene"
+ - "fhrs:id"
+ - "fhrs:inspectiondate"
+ - "fhrs:local_authority_id"
+ - "fhrs:rating"
+ - "fhrs:rating_date"
+ - "flickr"
+ - "full_name"
+ - "genus"
+ - "height"
+ - "image"
+ - "information"
+ - "inscription"
+ - "int_name"
+ - "is_in"
+ - "last_collection"
+ - "local_ref"
+ - "manufacturer"
+ - "media:commons"
+ - "min_height"
+ - "name"
+ - "naptan:AltCommonName"
+ - "naptan:AltStreet"
+ - "naptan:AtcoCode"
+ - "naptan:Bearing"
+ - "naptan:BusStopType"
+ - "naptan:CommonName"
+ - "naptan:Crossing"
+ - "naptan:Indicator"
+ - "naptan:Landmark"
+ - "naptan:NaptanCode"
+ - "naptan:Notes"
+ - "naptan:PlusbusZoneRef"
+ - "naptan:ShotCommonName"
+ - "naptan:Street"
+ - "naptan:verified"
+ - "network"
+ - "official_name"
+ - "old_name"
+ - "opening_hours"
+ - "opening_hours:url"
+ - "operator"
+ - "phone"
+ - "phone_1"
+ - "platforms"
+ - "postal_code"
+ - "ref"
+ - "ref_no"
+ - "route_ref"
+ - "royal_cypher"
+ - "seats"
+ - "species"
+ - "start_date"
+ - "survey:date"
+ - "taxon"
+ - "telephone"
+ - "twitter"
+ - "uk_postcode_centroid"
+ - "uri"
+ - "url"
+ - "voltage"
+ - "website"
+ - "website_2"
+ - "wikidata"
+ - "wikipedia"
+
+prefix_to_write:
+ - "addr"
+ - "alt_name"
+ - "contact"
+ - "description"
+ - "genus"
+ - "inscription"
+ - "is_in"
+ - "manufacturer"
+ - "name"
+ - "old_name"
+ - "operator"
+ - "route_ref"
+ - "species"
+ - "taxon"
+ - "website"
+ - "wikipedia"
+
+keys_to_skip:
+ - "FIXME"
+ - "attribution"
+ - "building:levels"
+ - "building:part"
+ - "comment"
+ - "created_by"
+ - "curve_geometry"
+ - "diameter_crown"
+ - "fixme"
+ - "import_uuid"
+ - "indoor"
+ - "junction"
+ - "layer"
+ - "level"
+ - "level:ref"
+ - "location:transition"
+ - "mapillary"
+ - "naptan:verified:note"
+ - "note"
+ - "osak:identifier"
+ - "place"
+ - "ref:opendataparis:adresse"
+ - "ref:opendataparis:geo_point_2d"
+ - "source"
+ - "source_ref"
+
+prefix_to_skip:
+ - "demolished"
+ - "mapillary"
+ - "source"
+ - "razed"
+ - "removed"
+
+tags_to_skip:
+ highway: motorway_junction
diff --git a/mia/bev/styles/mia.yml b/mia/bev/styles/mia.yml
new file mode 100644
index 0000000000000000000000000000000000000000..56dd490c6864a296c3c92d5c1ce1540fac6ebc6e
--- /dev/null
+++ b/mia/bev/styles/mia.yml
@@ -0,0 +1,1116 @@
+options:
+
+ draw_nodes: no
+ draw_trees: no
+ draw_craters: no
+ draw_buildings: yes
+ draw_directions: no
+
+ driving_side: right
+ infer_sidewalks: false
+
+colors:
+ # General guidelines:
+ def_road_color: "#000" # Black
+ def_crossing_color: "#F00" # Red
+ def_explicit_pedestrian: "#FF0" # Yellow
+ def_explicit_void: "#FFF" # White
+ def_park_color: "#0F0" # Green
+ def_building_color: "#F0F" # Magenta
+ def_water_color: "#00F" # Blue
+ def_terrain_color: "#0FF" # Cyan
+ def_parking_color: "#AAA" # Dark Grey
+ def_train_color: "#555" # Light Grey
+
+
+ # Entity
+ def_default: {color: def_explicit_void}
+ default: "#FFF" # Must be hardcoded for some reason
+ extra: {color: def_default}
+
+ background_color: {color: def_default}
+ road_color: {color: def_road_color}
+
+ wheat_color: {color: def_terrain_color}
+ wheat_border_color: {color: def_terrain_color}
+ wheat_dark_color: {color: def_terrain_color}
+ rye_color: {color: def_terrain_color}
+ rye_dark_color: {color: def_terrain_color}
+ oat_color: {color: def_terrain_color}
+ oat_dark_color: {color: def_terrain_color}
+ barley_color: {color: def_terrain_color}
+ barley_border_color: {color: def_terrain_color}
+ barley_dark_color: {color: def_terrain_color}
+ sunflower_dark_color: {color: def_terrain_color}
+
+ motorway_border_color: {color: def_road_color}
+ motorway_color: {color: def_road_color}
+ primary_border_color: {color: def_road_color}
+ primary_color: {color: def_road_color}
+ secondary_border_color: {color: def_road_color}
+ secondary_color: {color: def_road_color}
+ tertiary_border_color: {color: def_road_color}
+ tertiary_color: {color: def_road_color}
+
+ bridge_color: {color: def_road_color}
+ ford_color: {color: def_water_color}
+ embankment_color_color: {color: def_default}
+
+ allotments_color: "#FFFFFF"
+ beach_color: {color: def_default}
+ building_border_color: {color: def_building_color}
+ building_color: {color: def_building_color}
+ building_construction_border_color: {color: def_building_color}
+ building_construction_color: {color: def_building_color}
+ construction_color: {color: def_building_color}
+
+ cycle_color: {color: def_explicit_pedestrian}
+ desert_color: {color: def_terrain_color}
+ decidious_color: {color: def_default}
+ emergency_color: {color: def_default}
+ evergreen_color: {color: def_terrain_color}
+ farmland_color: {color: def_terrain_color}
+ farmland_border_color: {color: def_terrain_color}
+ farmland_darker_color: {color: def_terrain_color}
+ ferry_terminal_color: {color: def_building_color}
+ foot_area_color: {color: def_explicit_pedestrian}
+ foot_area_border_color: {color: def_explicit_pedestrian}
+ foot_border_color: {color: def_explicit_pedestrian}
+ foot_color: {color: def_explicit_pedestrian}
+ grass_border_color: {color: def_terrain_color}
+ grass_color: {color: def_terrain_color}
+ hidden_color: {color: def_default}
+ indoor_border_color: {color: def_building_color}
+ indoor_color: {color: def_building_color}
+ indoor_column_color: {color: def_building_color}
+ meadow_border_color: {color: def_terrain_color}
+ meadow_color: {color: def_terrain_color}
+ orchard_color: {color: def_terrain_color}
+ orchard_border_color: {color: def_terrain_color}
+ outline_color: {color: def_default}
+ parking_color: {color: def_parking_color}
+ park_color: {color: def_park_color}
+ pitch_color: {color: def_terrain_color}
+ pitch_border_color: {color: def_terrain_color}
+ platform_border_color: {color: def_building_color}
+ platform_color: {color: def_building_color}
+ playground_border_color: {color: def_park_color}
+ playground_color: {color: def_park_color}
+ ridge_color: {color: def_default}
+ road_border_color: {color: def_road_color}
+ rock_color: {color: def_default}
+ route_color: {color: def_road_color}
+ sand_color: {color: def_default}
+ scree_color: {color: def_default}
+ track_color: {color: def_default}
+ trunk_color: {color: def_default}
+ tree_color: {color: def_default}
+ village_green_color: {color: def_default}
+ wall_bottom_1_color: {color: def_building_color}
+ wall_bottom_2_color: {color: def_building_color}
+ wall_color: {color: def_building_color}
+ wall_construction_color: {color: def_building_color}
+ water_border_color: {color: def_water_color}
+ water_color: {color: def_water_color}
+ wetland_color: {color: def_water_color}
+ wood_border_color: {color: def_terrain_color}
+ wood_color: {color: def_terrain_color}
+ sidewalk_color: {color: def_explicit_pedestrian}
+
+ runway_color: {color: def_road_color}
+ runway_border_color: {color: def_road_color}
+ taxiway_color: {color: def_road_color}
+ taxiway_border_color: {color: def_road_color}
+
+ sell_color: "#880088"
+ craft_color: "#008800"
+
+ # Colors not in W3C
+
+ rose: "#FF007F" # Wikipedia
+ slate_blue: "#6A5ACD" # W3C slateblue
+
+carto_colors:
+ building_border_color: "#00FF00"
+ building_color: "#00FF00"
+ cemetery_color: "#00FFFF"
+ commercial_color: "#FFFF00"
+ commercial_border_color: "#FFFF00"
+ grass_color: "#00FF00"
+ industrial_color: "#FFFF00"
+ industrial_border_color: "#FFFF00"
+ military_color: "#FFFF00"
+ park_color: "#00FF00"
+ residential_color: "#FFFF00"
+ residential_border_color: "#FFFF00"
+
+material_colors:
+
+ bronze: "#CD7F32"
+ concrete: "#AAAAAA"
+ glass: "#CCEEFF"
+
+roads:
+ - tags: {highway: motorway}
+ default_width: 13.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.8
+ - tags: {highway: trunk}
+ default_width: 13.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.0
+ - tags: {highway: trunk_link}
+ default_width: 12.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.0
+ - tags: {highway: primary}
+ default_width: 13.0
+ border_color: primary_border_color
+ color: primary_color
+ priority: 41.7
+ - tags: {highway: motorway_link}
+ default_width: 12.0
+ border_color: motorway_border_color
+ color: motorway_color
+ priority: 41.8
+ - tags: {highway: secondary}
+ default_width: 13.0
+ border_color: secondary_border_color
+ priority: 41.6
+ color: secondary_color
+ - tags: {highway: secondary_link}
+ default_width: 12.0
+ border_color: secondary_border_color
+ priority: 41.6
+ color: secondary_color
+ - tags: {highway: tertiary}
+ default_width: 12.0
+ border_color: tertiary_border_color
+ priority: 41.5
+ color: tertiary_color
+ - tags: {highway: tertiary_link}
+ default_width: 11.0
+ border_color: tertiary_border_color
+ priority: 41.5
+ color: tertiary_color
+ - tags: {highway: unclassified}
+ default_width: 8.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: residential}
+ default_width: 10.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: living_street}
+ default_width: 7.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: service}
+ exception: {service: parking_aisle}
+ default_width: 5.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {highway: service, service: parking_aisle}
+ default_width: 4.0
+ border_color: road_border_color
+ priority: 41.0
+ - tags: {leisure: track}
+ color: pitch_color
+ border_color: pitch_border_color
+ default_width: 6.0
+ priority: 21.0
+
+ - tags: {highway: raceway}
+ color: pitch_color
+ border_color: pitch_border_color
+ default_width: 8.0
+ priority: 21.0
+
+ways:
+ - tags: {man_made: bridge}
+ style: {fill: def_road_color}
+ priority: 22.0
+ - tags: {indoor: area}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 10.0
+ - tags: {indoor: corridor}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 11.0
+ - tags: {highway: corridor}
+ style:
+ stroke: "#00FF00"
+ stroke-width: 5.0
+ priority: 11.0
+ - tags: {indoor: "yes", area: "yes"}
+ style:
+ stroke: indoor_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 12.0
+ - tags: {indoor: room}
+ style:
+ stroke: indoor_border_color
+ stroke-width: 1.0
+ priority: 12.0
+ - tags: {indoor: elevator, area: "yes"}
+ style:
+ stroke: indoor_color
+ stroke-width: 1.0
+ fill: indoor_color
+ priority: 12.0
+ - tags: {indoor: column}
+ style:
+ stroke: indoor_column_color
+ stroke-width: 1.0
+ fill: indoor_column_color
+ priority: 13.0
+
+ - tags: {power: line}
+ style:
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 00
+ priority: 80.0
+ - tags: {power: cable}
+ style:
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 80.0
+ - tags: {golf: hole}
+ style:
+ stroke: "#44AA33"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 80.0
+ - tags: {man_made: pipeline}
+ style:
+ stroke: "#888888"
+ stroke-width: 1.0
+ stroke-dasharray: "12.0,1.5"
+ opacity: 0
+ priority: 80.0
+ - tags: {man_made: pipeline}
+ style:
+ stroke: "#888888"
+ stroke-width: 3.0
+ stroke-dasharray: "1.0,10.0,1.0,1.5"
+ opacity: 0
+ priority: 80.0
+
+ - tags: {attraction: water_slide}
+ style:
+ stroke: "#FFFFFF"
+ stroke-width: 2.0
+ priority: 81
+ - tags: {attraction: water_slide}
+ style:
+ stroke: "#888888"
+ stroke-width: 4.0
+ priority: 80.0
+
+ - tags: {highway: track}
+ style:
+ stroke-width: 2
+ stroke: track_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: footway}
+ exception: {area: "yes", type: "multipolygon"}
+ style:
+ stroke-width: 5.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {footway: crossing}
+ style:
+ stroke-width: 5.0
+ stroke: def_crossing_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 99.0
+ - tags: {cycleway: crossing}
+ style:
+ stroke-width: 5.0
+ stroke: def_crossing_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 99.0
+ - tags: {highway: pedestrian}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 5.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: cycleway}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 41.0
+ - tags: {highway: steps}
+ style:
+ stroke-width: 6.0
+ stroke: foot_border_color
+ stroke-linecap: butt
+ - tags: {highway: path}
+ style:
+ stroke-width: 3.0
+ stroke: foot_border_color
+ priority: 41.0
+
+ - tags: {highway: footway}
+ exception: {area: "yes", type: "multipolygon"}
+ style:
+ stroke-width: 4
+ stroke-dasharray: 7.0,3.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: pedestrian}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 4
+ stroke-dasharray: 7.0,3.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: footway, area: "yes"}
+ style:
+ stroke: foot_area_border_color
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 55.0
+ - tags: {highway: footway, type: "multipolygon"}
+ style:
+ stroke: foot_area_border_color
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 55.0
+ - tags: {highway: pedestrian, area: "yes"}
+ style:
+ stroke: none
+ fill: foot_area_color
+ stroke-linecap: round
+ stroke-linejoin: round
+ priority: 55.0
+ - tags: {highway: cycleway}
+ exception: {area: "yes"}
+ style:
+ stroke-width: 1.5
+ stroke: cycle_color
+ stroke-dasharray: 8.0,2.0
+ stroke-linecap: butt
+ priority: 42.0
+ - tags: {highway: steps, conveying: "*"}
+ style:
+ stroke-width: 5.0
+ stroke-dasharray: 1.5,2.0
+ stroke-linecap: butt
+ stroke: "#888888"
+ priority: 42.0
+ - tags: {highway: steps}
+ exception: {conveying: "*"}
+ style:
+ stroke-width: 5.0
+ stroke-dasharray: 1.5,2.0
+ stroke-linecap: butt
+ stroke: foot_color
+ priority: 42.0
+ - tags: {highway: path}
+ style:
+ stroke-width: 4
+ stroke-dasharray: 5.0,3.0
+ stroke-linecap: butt
+ stroke: foot_color
+ priority: 42.0
+
+ - tags: {aeroway: runway}
+ style:
+ stroke-width: 50.0
+ stroke: runway_color
+ priority: 22.0
+ - tags: {aeroway: taxiway}
+ style:
+ stroke-width: 50.0
+ stroke: taxiway_color
+ priority: 21.0
+ - tags: {aeroway: runway}
+ style:
+ stroke-width: 2.0
+ stroke: "#DDDDDD"
+ stroke-dasharray: 40.0,20.0
+ priority: 23.0
+ - tags: {aeroway: taxiway}
+ style:
+ stroke-width: 1.0
+ stroke: "#CCCCCC"
+ priority: 23.0
+ - tags: {aeroway: parking_position}
+ style:
+ stroke-width: 1.0
+ stroke: "#DDCC00"
+ priority: 23.0
+ - tags: {area:aeroway: taxiway}
+ style:
+ fill: "#CCCCCC"
+ priority: 20.0
+
+ - tags: {natural: wood}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: wetland}
+ style:
+ fill: wetland_color
+ priority: 21.0
+ - tags: {natural: grassland}
+ style:
+ fill: grass_color
+ stroke: grass_border_color
+ priority: 20.0
+ - tags: {natural: scrub}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: sand}
+ style:
+ fill: sand_color
+ priority: 20.0
+ - tags: {natural: beach}
+ style:
+ fill: beach_color
+ priority: 20.0
+ - tags: {natural: heath}
+ style:
+ fill: "#DDDDDD"
+ priority: 20.0
+ - tags: {natural: glacier}
+ style:
+ fill: "#FFFFFF"
+ priority: 20.0
+ - tags: {natural: desert}
+ style:
+ fill: desert_color
+ priority: 20.0
+ - tags: {natural: forest}
+ style:
+ fill: wood_color
+ priority: 21.0
+ - tags: {natural: tree_row}
+ priority: 21.0
+ style:
+ stroke: wood_color
+ stroke-width: 5.0
+ stroke-linecap: round
+ stroke-linejoin: round
+ - tags: {natural: water}
+ exception: {intermittent: "yes"}
+ style:
+ fill: def_water_color
+ # stroke: water_border_color
+ # stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: water, intermittent: "yes"}
+ style:
+ fill: def_water_color
+ opacity: 1
+ # stroke: water_border_color
+ # stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: coastline}
+ style:
+ # fill: def_water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {natural: ridge}
+ style:
+ stroke-width: 2.0
+ opacity: 0
+ stroke: ridge_color
+ priority: 21.0
+ - tags: {natural: bare_rock}
+ style:
+ fill: rock_color
+ - tags: {natural: cliff}
+ style:
+ stroke-width: 1.0
+ stroke: "#BBBBBB"
+ - tags: {natural: cliff}
+ parallel_offset: -2.5
+ style:
+ stroke-width: 5.0
+ stroke: "#BBBBBB"
+ stroke-dasharray: "1,10"
+ - tags: {natural: scree}
+ style:
+ fill: scree_color
+
+ - tags: {landuse: allotments}
+ style:
+ fill: allotments_color
+ priority: 20.0
+ - tags: {landuse: conservation}
+ style:
+ fill: grass_color
+ priority: 20.0
+ - tags: {landuse: construction}
+ style:
+ fill: construction_color
+ - tags: {landuse: farmland}
+ style: {fill: farmland_color, stroke: farmland_border_color}
+ priority: 20.0
+ - tags: {landuse: greenhouse_horticulture}
+ style: {fill: farmland_color, stroke: farmland_border_color}
+ priority: 20.0
+ - tags: {landuse: farmyard}
+ style: {fill: farmland_color, stroke: farmland_border_color} # FIXME
+ priority: 20.0
+ - tags: {landuse: farmland, crop: wheat}
+ style: {fill: wheat_color, stroke: wheat_border_color}
+ priority: 20.0
+ - tags: {landuse: farmland, crop: barley}
+ style: {fill: barley_color, stroke: barley_border_color}
+ priority: 20.0
+ - tags: {landuse: farmland, crop: rye}
+ style: {fill: rye_color, stroke: rye_dark_color}
+ priority: 20.0
+ - tags: {landuse: forest}
+ style:
+ fill: wood_color
+ priority: 20.0
+ - tags: {landuse: garages}
+ style:
+ fill: parking_color
+ priority: 21.0
+ - tags: {landuse: village_green}
+ style:
+ fill: village_green_color
+ priority: 20.0
+ - tags: {landuse: grass}
+ style:
+ fill: grass_color
+ stroke: grass_border_color
+ priority: 20.0
+ - tags: {landuse: orchard}
+ style:
+ fill: orchard_color
+ priority: 21.0
+ - tags: {landuse: meadow}
+ style:
+ fill: meadow_color
+ stroke: meadow_border_color
+ priority: 20.0
+ - tags: {tourism: camp_pitch}
+ style:
+ stroke: "#000000"
+ opacity: 0
+ priority: 100.0
+
+ # Hidden land use
+
+ - tags: {landuse: cemetery}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {landuse: commercial}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {landuse: industrial}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {landuse: military}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {landuse: railway}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {landuse: residential}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+ - tags: {power: substation}
+ style:
+ fill: hidden_color
+ opacity: 0
+ priority: 1.0
+
+ - tags: {amenity: ferry_terminal}
+ style:
+ fill: ferry_terminal_color
+ priority: 50.0
+ - tags: {amenity: parking}
+ style:
+ fill: parking_color
+ opacity: 1
+ - tags: {amenity: parking_space}
+ style:
+ stroke: "#FFFFFF"
+
+ - tags: {aeroway: landingpad}
+ style:
+ fill: "#000000"
+ opacity: 0
+ - tags: {aeroway: helipad}
+ style:
+ fill: "#440044"
+ opacity: 0
+
+ - tags: {waterway: river}
+ style:
+ stroke: def_water_color
+ stroke-width: 2.5
+ priority: 22.0
+ - tags: {waterway: canal}
+ style:
+ stroke: def_water_color
+ stroke-width: 2.0
+ priority: 22.0
+ - tags: {waterway: stream}
+ style:
+ stroke: def_water_color
+ stroke-width: 1.5
+ priority: 22.0
+ - tags: {waterway: riverbank}
+ style:
+ fill: def_water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+ priority: 22.0
+ - tags: {waterway: ditch}
+ style:
+ fill: def_water_color
+ stroke: def_water_color
+ stroke-width: 2.0
+ priority: 22.0
+
+ - tags: {railway: subway}
+ style:
+ stroke-width: 7
+ opacity: 0
+ stroke: def_train_color
+ priority: 41.0
+ - tags: {railway: rail}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: light_rail}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: monorail}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: funicular}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: narrow_gauge}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: tram}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ priority: 42.0
+ - tags: {railway: construction}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ stroke-dasharray: 6,3
+ opacity: 0
+ priority: 42.0
+ - tags: {railway: disused}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ stroke-dasharray: 6,6
+ opacity: 0
+ priority: 42.0
+ - tags: {railway: abandoned}
+ style:
+ stroke-width: 7.0
+ stroke: def_train_color
+ stroke-dasharray: 6,9
+ opacity: 0
+ priority: 42.0
+
+ - tags: {railway: rail}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+ - tags: {railway: light_rail}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+ - tags: {railway: monorail}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+ - tags: {railway: funicular}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+ - tags: {railway: narrow_gauge}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+ - tags: {railway: tram}
+ style:
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 43.0
+
+ - tags: {railway: platform}
+ style:
+ fill: def_train_color
+ stroke-width: 7
+ stroke: def_train_color
+ priority: 41.0
+
+ - tags: {route: ferry}
+ style:
+ stroke-width: 1.0
+ stroke-dasharray: 3.0,3.0
+ stroke-linecap: butt
+ stroke: route_color
+ priority: 42.0
+
+ - tags: {leisure: garden}
+ style:
+ fill: grass_color
+ priority: 21.0
+ - tags: {leisure: park}
+ style:
+ fill: park_color
+ opacity: 1
+ - tags: {landuse: recreation_ground}
+ style:
+ fill: grass_color
+ opacity: 1
+ - tags: {leisure: recreation_ground}
+ style:
+ fill: grass_color
+ opacity: 1
+ - tags: {leisure: stadium}
+ style:
+ fill: grass_color
+ opacity: 1
+ - tags: {leisure: golf_course}
+ style:
+ fill: grass_color
+ opacity: 1
+ - tags: {leisure: pitch}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: bleachers}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: track, area: "yes"}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: fitness_station}
+ style:
+ fill: pitch_color
+ stroke: pitch_border_color
+ stroke-width: 1.0
+ priority: 21.0
+ - tags: {leisure: playground}
+ style:
+ fill: playground_color
+ stroke: playground_border_color
+ priority: 21.0
+ - tags: {leisure: swimming_pool}
+ style:
+ fill: def_water_color
+ stroke: water_border_color
+ stroke-width: 1.0
+
+ - tags: {tourism: artwork}
+ style:
+ stroke: "#888888"
+ stroke-width: 1.0
+ priority: 10.0
+ - tags: {barrier: hedge}
+ style:
+ fill: none
+ stroke: wood_color
+ stroke-width: 4.0
+ priority: 40.0
+ - tags: {barrier: city_wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 2.0
+ opacity: 1
+ priority: 40.0
+ - tags: {barrier: wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.5
+ opacity: 1
+ priority: 40.0
+ - tags: {man_made: embankment}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 40.0
+ - tags: {man_made: embankment}
+ parallel_offset: -1.5
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 3.0
+ opacity: 0
+ stroke-dasharray: "1,7"
+ priority: 40.0
+ - tags: {barrier: fence}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 40.0
+ - tags: {barrier: retaining_wall}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 40.0
+ - tags: {barrier: handrail}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 40.0
+ - tags: {barrier: kerb}
+ style:
+ fill: none
+ stroke: "#000000"
+ stroke-width: 1.0
+ opacity: 0
+ priority: 40.0
+
+ - tags: {border: "*"}
+ style:
+ stroke: "#FF0000"
+ stroke-width: 0.5
+ stroke-dasharray: 10.0,20.0
+ - tags: {"area:highway": "*"}
+
+ - tags: {boundary: "*"}
+ # style:
+ # stroke: boundary_color
+ # stroke-width: 0.3
+ # stroke-dasharray: 10.0,5.0
+ priority: 60.0
+
+area_tags:
+ - tags: {aeroway: "*"}
+ - tags: {building: "*"}
+ - tags: {landuse: "*"}
+ - tags: {leisure: "*"}
+ - tags: {natural: "*"}
+ exception: {natural: "tree_row"}
+ - tags: {indoor: "corridor"}
+ - tags: {power: "compensator"}
+ - tags: {power: "substation"}
+
+keys_to_write:
+ - "STIF:zone"
+ - "alt_name"
+ - "artist_name"
+ - "booth"
+ - "branch"
+ - "brand"
+ - "capacity"
+ - "cladr:code"
+ - "collection_times"
+ - "created_by"
+ - "cuisine"
+ - "cyclestreets_id"
+ - "description"
+ - "designation"
+ - "destination"
+ - "ele"
+ - "email"
+ - "end_date"
+ - "facebook"
+ - "fax"
+ - "fhrs:confidence_management"
+ - "fhrs:hygiene"
+ - "fhrs:id"
+ - "fhrs:inspectiondate"
+ - "fhrs:local_authority_id"
+ - "fhrs:rating"
+ - "fhrs:rating_date"
+ - "flickr"
+ - "full_name"
+ - "genus"
+ - "height"
+ - "image"
+ - "information"
+ - "inscription"
+ - "int_name"
+ - "is_in"
+ - "last_collection"
+ - "local_ref"
+ - "manufacturer"
+ - "media:commons"
+ - "min_height"
+ - "name"
+ - "naptan:AltCommonName"
+ - "naptan:AltStreet"
+ - "naptan:AtcoCode"
+ - "naptan:Bearing"
+ - "naptan:BusStopType"
+ - "naptan:CommonName"
+ - "naptan:Crossing"
+ - "naptan:Indicator"
+ - "naptan:Landmark"
+ - "naptan:NaptanCode"
+ - "naptan:Notes"
+ - "naptan:PlusbusZoneRef"
+ - "naptan:ShotCommonName"
+ - "naptan:Street"
+ - "naptan:verified"
+ - "network"
+ - "official_name"
+ - "old_name"
+ - "opening_hours"
+ - "opening_hours:url"
+ - "operator"
+ - "phone"
+ - "phone_1"
+ - "platforms"
+ - "postal_code"
+ - "ref"
+ - "ref_no"
+ - "route_ref"
+ - "royal_cypher"
+ - "seats"
+ - "species"
+ - "start_date"
+ - "survey:date"
+ - "taxon"
+ - "telephone"
+ - "twitter"
+ - "uk_postcode_centroid"
+ - "uri"
+ - "url"
+ - "voltage"
+ - "website"
+ - "website_2"
+ - "wikidata"
+ - "wikipedia"
+
+prefix_to_write:
+ - "addr"
+ - "alt_name"
+ - "contact"
+ - "description"
+ - "genus"
+ - "inscription"
+ - "is_in"
+ - "manufacturer"
+ - "name"
+ - "old_name"
+ - "operator"
+ - "route_ref"
+ - "species"
+ - "taxon"
+ - "website"
+ - "wikipedia"
+
+keys_to_skip:
+ - "FIXME"
+ - "attribution"
+ - "building:levels"
+ - "building:part"
+ - "comment"
+ - "created_by"
+ - "curve_geometry"
+ - "diameter_crown"
+ - "fixme"
+ - "import_uuid"
+ - "indoor"
+ - "junction"
+ - "layer"
+ - "level"
+ - "level:ref"
+ - "location:transition"
+ - "mapillary"
+ - "naptan:verified:note"
+ - "note"
+ - "osak:identifier"
+ - "place"
+ - "ref:opendataparis:adresse"
+ - "ref:opendataparis:geo_point_2d"
+ - "source"
+ - "source_ref"
+
+prefix_to_skip:
+ - "demolished"
+ - "mapillary"
+ - "source"
+ - "razed"
+ - "removed"
+
+tags_to_skip:
+ highway: motorway_junction
\ No newline at end of file
diff --git a/mia/conf/example.yaml b/mia/conf/example.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..12b73b3b3b4b6b43e469e428b9902f71b5bc34f9
--- /dev/null
+++ b/mia/conf/example.yaml
@@ -0,0 +1,94 @@
+# Where to save the downloaded dataset
+dataset_dir: "datasets/new_locations"
+
+fpv_options:
+ # Pipeline configuration to filter FPV images
+ filter_pipeline_cfg: "mia/fpv/filter_pipelines/mia.yaml"
+
+ # Chunk size is used for checkpointing as image metadata
+ # can be very memory intensive. This allows you to resume
+ # from the last chunk in case things go wrong.
+ metadata_download_chunk_size: 50000
+
+ # FPV retrieval is comprised of the stages bellow
+ # these boolean flags allow you to only execute certain stages
+ # Note that the stages are ordered and later stages assume previous stages
+ # are complete.
+ stages:
+ get_image_points_from_tiles: True
+ get_metadata: True
+ run_filter: True
+ download_images: True
+ to_process_sequence: True
+
+bev_options:
+ # Local planet dump OSM file path. File format should be either .osm or .json.
+ # If not provided, tiled OSM data will be downloaded from the internet.
+ # Additionally, the OSM data will be used to clip images such that they all lie within its boundary box.
+ osm_fp:
+
+ # Download the OSM file encompoassing the whole area + some padding
+ # Pretty slow to render but can fix the problem of missing road segments,
+ # which happens when the small bboxes do not contain any of the ends of the road segment.
+ one_big_osm: False
+
+ # Padding in meters that allows rendering a bigger bounding box then cropping.
+ # Useful to reduce the problem of missing road segments.
+ # If one_big_osm option is turned on, this padding is added over the big map only.
+ padding: 50
+
+ # Download OSM data only and do not process maps into semantic masks.
+ download_osm_only: False
+
+ # If enabled, the osm_cache will store a file per ID which removes the need for synchronization
+ # between processes. If disabled, osm_cache will store files based on bbox queries, enabling
+ # reuse of osm tiles. This was observed to reduce needed OSM downloads by ~20% but will
+ # trigger file lock synchronization if using multiworkers to avoid race conditions.
+ store_osm_per_id: False
+
+ # MapMachine style sheet
+ map_machine_scheme: "mia/bev/styles/mia.yml"
+
+ #Final map pixel size after rotation
+ map_length: 224
+
+ # Map resolution in meters per pixel
+ meters_per_pixel: 0.5
+
+ # If downsampling the map after processing is desired
+ # You can use the below downsampling factor
+ final_downsample: 1
+
+ # Store satelite images as well using google earth engine.
+ # Requires you to already have a google earth engine project, be authenticated using `earthengine authenticate`
+ # and a the project id set using gcloud auth `gcloud auth application-default set-quota-project PROJECT_ID`)
+ store_sat: False
+
+ # Whether or not to store RAW BEV svgs and rendered semantic masks.
+ store_all_steps: False
+
+ # How many processes to use to process images. Set to 0 to disable multiprocessing.
+ n_workers: 0
+
+ # Redownload existing BEV images. Useful if style sheet is updated.
+ redownload: False
+
+ # Should we sample from the dataframe instead of downloading everything?
+ # Set to -1 to download everything
+ n_samples: -1
+
+# List all locations you want to download and process below
+cities:
+ - name: "Mount Oliver"
+ state: "Pennsylvania" # (Optional)
+ country: "United States" # (Optional)
+ bound_type: "auto_shape" # ["auto_shape", "auto_bbox", "custom_size", "custom_bbox"]
+
+ - name: "Greensburg"
+ state: "Pennsylvania"
+ bound_type: "custom_size"
+ custom_size: 5 # 5km x 5km centered on auto-fetched city center
+
+ - name: "Frick Park"
+ bound_type: "custom_bbox"
+ custom_bbox: "-79.9086,40.4294,-79.9052,40.4315" # East,South,West,North
\ No newline at end of file
diff --git a/mia/conf/mia.yaml b/mia/conf/mia.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f97720912c3d9065172801750759ecfb36f600bc
--- /dev/null
+++ b/mia/conf/mia.yaml
@@ -0,0 +1,109 @@
+dataset_dir: "datasets/new_locations"
+
+fpv_options:
+ # Pipeline configuration to filter FPV images
+ filter_pipeline_cfg: "mia/fpv/filter_pipelines/mia.yaml"
+
+ # Chunk size is used for checkpointing as image metadata
+ # can be very memory intensive. This allows you to resume
+ # from the last chunk in case things go wrong.
+ metadata_download_chunk_size: 50000
+
+ # FPV retrieval is comprised of the stages bellow
+ # these boolean flags allow you to only execute certain stages
+ # Note that the stages are ordered and later stages assume previous stages
+ # are complete.
+ stages:
+ get_image_points_from_tiles: True
+ get_metadata: True
+ run_filter: True
+ download_images: True
+ to_process_sequence: True
+
+bev_options:
+ # Local planet dump OSM file path. File format should be either .osm or .json.
+ # If not provided, tiled OSM data will be downloaded from the internet.
+ # Additionally, the OSM data will be used to clip images such that they all lie within its boundary box.
+ osm_fp:
+
+ # Download the OSM file encompoassing the whole area + some padding
+ # Pretty slow to render but can fix the problem of missing road segments,
+ # which happens when the small bboxes do not contain any of the ends of the road segment.
+ one_big_osm: False
+
+ # Padding in meters that allows rendering a bigger bounding box then cropping.
+ # Useful to reduce the problem of missing road segments.
+ # If one_big_osm option is turned on, this padding is added over the big map only.
+ padding: 50
+
+ # Download OSM data only and do not process maps into semantic masks.
+ download_osm_only: False
+
+ # If enabled, the osm_cache will store a file per ID which removes the need for synchronization
+ # between processes. If disabled, osm_cache will store files based on bbox queries, enabling
+ # reuse of osm tiles. This was observed to reduce needed OSM downloads by ~20% but will
+ # trigger file lock synchronization if using multiworkers to avoid race conditions.
+ store_osm_per_id: False
+
+ # MapMachine style sheet
+ map_machine_scheme: "mia/bev/styles/mia.yml"
+
+ #Final map pixel size after rotation
+ map_length: 224
+
+ # Map resolution in meters per pixel
+ meters_per_pixel: 0.5
+
+ # If downsampling the map after processing is desired
+ # You can use the below downsampling factor
+ final_downsample: 1
+
+ # Store satelite images as well using google earth engine.
+ # Requires you to already have a google earth engine project, be authenticated using `earthengine authenticate`
+ # and a the project id set using gcloud auth `gcloud auth application-default set-quota-project PROJECT_ID`)
+ store_sat: False
+
+ # Whether or not to store RAW BEV svgs and rendered semantic masks.
+ store_all_steps: False
+
+ # How many processes to use to process images. Set to 0 to disable multiprocessing.
+ n_workers: 0
+
+ # Redownload existing BEV images. Useful if style sheet is updated.
+ redownload: False
+
+ # Should we sample from the dataframe instead of downloading everything?
+ # Set to -1 to download everything
+ n_samples: -1
+
+
+cities:
+ - name: "Pittsburgh"
+ state: "Pennsylvania"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "New York"
+ state: "New York"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "Chicago"
+ state: "Illinois"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "San Francisco"
+ state: "California"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "Los Angeles"
+ state: "California"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "Houston"
+ state: "Texas"
+ country: "United States"
+ bound_type: "auto_shape"
\ No newline at end of file
diff --git a/mia/conf/mia_rural.yaml b/mia/conf/mia_rural.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..556717dbd5a550fe923c190670c0117083ed2126
--- /dev/null
+++ b/mia/conf/mia_rural.yaml
@@ -0,0 +1,91 @@
+dataset_dir: "datasets/new_locations"
+
+fpv_options:
+ # Pipeline configuration to filter FPV images
+ filter_pipeline_cfg: "mia/fpv/filter_pipelines/mia.yaml"
+
+ # Chunk size is used for checkpointing as image metadata
+ # can be very memory intensive. This allows you to resume
+ # from the last chunk in case things go wrong.
+ metadata_download_chunk_size: 50000
+
+ # FPV retrieval is comprised of the stages bellow
+ # these boolean flags allow you to only execute certain stages
+ # Note that the stages are ordered and later stages assume previous stages
+ # are complete.
+ stages:
+ get_image_points_from_tiles: True
+ get_metadata: True
+ run_filter: True
+ download_images: True
+ to_process_sequence: True
+
+bev_options:
+ # Local planet dump OSM file path. File format should be either .osm or .json.
+ # If not provided, tiled OSM data will be downloaded from the internet.
+ # Additionally, the OSM data will be used to clip images such that they all lie within its boundary box.
+ osm_fp:
+
+ # Download the OSM file encompoassing the whole area + some padding
+ # Pretty slow to render but can fix the problem of missing road segments,
+ # which happens when the small bboxes do not contain any of the ends of the road segment.
+ one_big_osm: False
+
+ # Padding in meters that allows rendering a bigger bounding box then cropping.
+ # Useful to reduce the problem of missing road segments.
+ # If one_big_osm option is turned on, this padding is added over the big map only.
+ padding: 50
+
+ # Download OSM data only and do not process maps into semantic masks.
+ download_osm_only: False
+
+ # If enabled, the osm_cache will store a file per ID which removes the need for synchronization
+ # between processes. If disabled, osm_cache will store files based on bbox queries, enabling
+ # reuse of osm tiles. This was observed to reduce needed OSM downloads by ~20% but will
+ # trigger file lock synchronization if using multiworkers to avoid race conditions.
+ store_osm_per_id: False
+
+ # MapMachine style sheet
+ map_machine_scheme: "mia/bev/styles/mia.yml"
+
+ #Final map pixel size after rotation
+ map_length: 224
+
+ # Map resolution in meters per pixel
+ meters_per_pixel: 0.5
+
+ # If downsampling the map after processing is desired
+ # You can use the below downsampling factor
+ final_downsample: 1
+
+ # Store satelite images as well using google earth engine.
+ # Requires you to already have a google earth engine project, be authenticated using `earthengine authenticate`
+ # and a the project id set using gcloud auth `gcloud auth application-default set-quota-project PROJECT_ID`)
+ store_sat: False
+
+ # Whether or not to store RAW BEV svgs and rendered semantic masks.
+ store_all_steps: False
+
+ # How many processes to use to process images. Set to 0 to disable multiprocessing.
+ n_workers: 0
+
+ # Redownload existing BEV images. Useful if style sheet is updated.
+ redownload: False
+
+ # Should we sample from the dataframe instead of downloading everything?
+ # Set to -1 to download everything
+ n_samples: -1
+
+cities:
+ - name: "Willow"
+ state: "Alaska"
+ country: "United States"
+ bound_type: "auto_shape"
+
+ - name: "Ely"
+ state: "Nevada"
+ country: "United States"
+ bound_type: "custom_bbox"
+ custom_bbox: "-115.1024,39.0709,-114.7045,39.2321"
+
+
diff --git a/mia/dataset.md b/mia/dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..b79e36dbab13e2bbd40edb36069bf7bf76571d80
--- /dev/null
+++ b/mia/dataset.md
@@ -0,0 +1,134 @@
+# Map It Anywhere (MIA) Dataset
+
+## Table of Contents
+ - [Introduction](#introduction)
+ - [Data](#data)
+ - [Dataset Structure](#dataset-structure)
+ - [Format](#format)
+ - [Dataset Creation Summary](#dataset-creation)
+ - [Getting Started](#getting-started)
+ - [Licenses](#licenses)
+
+
+
+
+## Introduction
+The Map It Anywhere (MIA) dataset contains large-scale map-prediction-ready data curated from public datasets.
+Specifically, the dataset empowers Bird's Eye View (BEV) map prediction given First Person View (FPV) RGB images, by providing diversity in location and cameras beyond current datasets. The dataset contains 1.2 million high quality first-person-view (FPV) and bird's eye view (BEV) map pairs covering 470 squared kilometers, which to the best of our knowledge provides 6x more coverage than the closest publicly available map prediction dataset, thereby facilitating future map prediction research on generalizability and robustness. The dataset is curated using our MIA data engine [code](https://github.com/MapItAnywhere/MapItAnywhere) to sample from six urban-centered location: New York, Chicago, Houston, Los Angeles, Pittsburgh, and San Francisco.
+
+## Data
+### Dataset Structure
+
+```
+ROOT
+|
+--- LOCATION_0 # location folder
+| |
+| +--- images # FPV Images (XX.jpg)
+| +--- semantic_masks # Semantic Masks (XX.npz)
+| +--- flood_fill # Visibility Masks (XX.npz)
+| ---- dump.json # Camera pose information for IDs in LOCATION
+| ---- image_points.parquet
+| ---- image_metadata.parquet
+| ---- image_metadata_filtered.parquet
+| ---- image_metadata_filtered_processed.parquet
+--- LOCATION_1
+.
+.
+|
++-- LOCATION_2
+--- README.md
+--- samples.pdf # Visualization of sample data
+```
+
+
+## Format
+
+Each data sample has a unique ID given by Mapillary and is used to reference and associate attributes related to the sample throughout the dataset.
+
+**Each location has the following:**
+
+- `images` Directory containing all FPV images named as `_undistorted.png`
+- `semantic_masks` npz files named as `` containing semantic masks in the format of a single array `arr_0` with shape 224x224x8 where the 3rd dimension maps to classes as follows:
+ 0. road
+ 1. crossing
+ 2. explicit_pedestrian
+ 3. park (Unused by Mapper)
+ 4. building
+ 5. water (Unused by Mapper)
+ 6. terrain
+ 7. parking
+ 8. train (Unused by Mapper)
+- `flood_masks` npz files named as `` containing an observable region mask in the format of a single array `arr_0` with shape 224x224.
+- `image_points.parquet` dataframe containing all image points retrieved within the tiles encompassing the boundary.
+- `image_metadata.parquet` dataframe including metadata retrieved for each image point retrieved (After boundary filtering). The metadata retrieved is documented in the [Mapillary API](https://www.mapillary.com/developer/api-documentation#image)
+- `image_metadata_filtered.parquet` As above but only keeping filtered records
+- `image_metadata_filtered_processed.parquet` the final dataframe after FPV processing and spatial filtering and is the one that reflects what to expect in `images` directory.
+- `dump.json` a json file containing camera intrinsics and extrinsics for each image taken. Same format as [OrienterNet](https://github.com/facebookresearch/OrienterNet).
+
+In addition `split.json` is a file at the root that describes our training, validation, and testing splits.
+
+**Note** that throughout the pipeline, some data samples are unable to be processed fully due to API issues or processing limitations. Such data samples may have residues in dataframes or split files but may not have corresponding maps or flood masks. Thus, a valid data sample is defined as one that has a corresponding image, metadata record, semantic mask, and flood mask. The invalid data samples are less than 0.001% and will be cleaned up in later versions.
+
+
+## Dataset Creation
+
+
+**Overview of how MIA data engine enables automatic curation of FPV & BEV data.**
+Given names of cities as input from the left, the top row shows FPV processing, while the bottom row depicts BEV processing. Both pipelines converge on the right, producing FPV, BEV, and pose tuples. For more information, please reference the main paper.
+
+### Curation Rationale
+
+The MIA data engine and dataset were created to accelerate research progress towards anywhere map prediction. Current map prediction research builds on only a few map prediction datasets released by autonomous vehicle companies, which cover very limited area. We therefore present the MIA data engine, a more scalable approach by sourcing from large-scale crowd-sourced mapping platforms, Mapillary for FPV images and OpenStreetMap for BEV semantic maps.
+
+
+### Source Data
+
+The MIA dataset includes data from two sources: [Mapillary](https://www.mapillary.com/) for First-Person-View (FPV) images, and [OpenStreetMap](https://www.openstreetmap.org) for Bird-Eye-View (BEV) maps.
+
+For FPV retrieval, we leverage Mapillary, a massive public database, licensed under [CC BY-SA](https://creativecommons.org/licenses/by-sa/4.0/), with over 2 billion crowd-sourced images. The images span various weather and lighting conditions collected using diverse camera models and focal lengths. Furthermore, images are taken by pedestrians, vehicles, bicyclists, etc. This diversity enables the collection of more dynamic and difficult scenarios critical for anywhere map prediction.
+When uploading to the Mapillary platform, users submit them under Mapillary's terms and all images shared are under a CC-BY-SA license, more details can be found in [Mapillary License Page](https://help.mapillary.com/hc/en-us/articles/115001770409-Licenses).
+In addition, Mapillary integrates several mechanisms to minimize privacy concerns, such as applying technology to blur any faces and license plates, requiring users to notify if they observe any imageries that may contain personal data. More information can be found on the [Mapillary Privacy Policy page](https://www.mapillary.com/privacy).
+
+For BEV retrieval, we leverage OpenStreetMap (OSM), a global crowd-sourced mapping platform open-sourced under [Open Data Commons Open Database License (ODbL)](https://opendatacommons.org/licenses/odbl/). OSM provides
+rich vectorized annotations for streets, sidewalks, buildings, etc. OpenStreetMap has limitations on mapping private information where "it violates the privacy
+of people living in this world", with guidelines found [here](https://wiki.openstreetmap.org/wiki/Limitations_on_mapping_private_information).
+
+
+### Bias, Risks, and Limitations
+
+While we show promising generalization performance on conventional datasets, we note that label noise inherently exists, to a higher degree
+than manually collected data, in crowd sourced data, in both pose correspondence, and in BEV map labeling. Such noise is common across large-scale
+automatically scraped/curated benchmarks such as ImageNet. While we recognize that our sampled dataset is biased towards locations in the US, our MIA data engine is
+applicable to other world-wide locations.
+Our work relies heavily on crowd sourced data putting the burden of data collection on people and open-source contributions.
+
+
+## Getting Started
+1. [Download the dataset](https://cmu.box.com/s/6tnlvikg1rcsai0ve7t8kgdx9ago9x9q).
+2. Unzip all locations of interest into the same structure described above, such that a root folder contains all location folders directly.
+3. (Optional) Verify your download by visualizing a few samples using the tool `mia/misc_tools/vis_samples.py`.
+ 1. Build the docker image `mia/Dockerfile` if you haven't already by running:
+
+ docker build -t mia:release mia
+ 2. Launch the container while mounting your dataset root folder as well as this repository
+
+ docker run -v :/home/mia_dataset_release -v :/home/MapItAnywhere --network=bridge -it mia:release
+ 3. From inside the container run:
+
+ cd /home/MapItAnywhere
+
+ python3.9 -m mia.misc_tools.vis_samples --dataset_dir /home/mia_dataset_release --locations pittsburgh
+
+ If successful, the script will generate a PDF called `compare.pdf` in the pittsburgh directory. Upon openning you should see the metadata, FPVs, and BEVs of a few samples of the dataset. Note that satellite imagery is not provided as part of the dataset and is only used for comparison purposes.
+
+4. Enjoy and explore! Don't hesitate to raise a GitHub issue if you encounter any problems.
+
+Samples and key metadata information in `compare.pdf` will look like the following:
+
+
+## Licenses
+The FPVs were curated and processed from Mapillary and have the same [CC by SA license](https://creativecommons.org/licenses/by-sa/4.0/deed.en). These include all images files, parquet dataframes, and dump.json.
+The BEVs were curated and processed from OpenStreetMap and has the same [Open Data Commons Open Database (ODbL) License](https://opendatacommons.org/licenses/odbl/). These include all semantic masks and flood masks.
+The rest of the data is licensed under [CC by SA license](https://creativecommons.org/licenses/by-sa/4.0/deed.en).
+
diff --git a/mia/fpv/download.py b/mia/fpv/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..274bb29f98cf7c180115b90208d2a9f2f0be76ea
--- /dev/null
+++ b/mia/fpv/download.py
@@ -0,0 +1,325 @@
+# Adapted from OrienterNet
+
+import json
+from pathlib import Path
+
+import numpy as np
+import httpx
+import asyncio
+from aiolimiter import AsyncLimiter
+import tqdm
+import requests
+import mercantile
+import geojson
+import turfpy.measurement
+from vt2geojson.tools import vt_bytes_to_geojson
+
+
+from opensfm.pygeometry import Camera, Pose
+from opensfm.pymap import Shot
+
+from .. import logger
+from .geo import Projection
+
+
+semaphore = asyncio.Semaphore(100) # number of parallel threads.
+image_filename = "{image_id}.jpg"
+info_filename = "{image_id}.json"
+
+
+class MapillaryDownloader:
+ image_fields = (
+ "id",
+ "height",
+ "width",
+ "camera_parameters",
+ "camera_type",
+ "captured_at",
+ "compass_angle",
+ "geometry",
+ "altitude",
+ "computed_compass_angle",
+ "computed_geometry",
+ "computed_altitude",
+ "computed_rotation",
+ "thumb_2048_url",
+ "thumb_original_url",
+ "sequence",
+ "sfm_cluster",
+ "creator",
+ "make",
+ "model",
+ "is_pano",
+ "quality_score",
+ "exif_orientation"
+ )
+ image_info_url = (
+ "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
+ )
+ seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}"
+ tile_info_url = "https://tiles.mapillary.com/maps/vtp/mly1_public/2/{z}/{x}/{y}?access_token={token}"
+ max_requests_per_minute = 50_000
+
+ def __init__(self, token: str):
+ self.token = token
+ self.client = httpx.AsyncClient(
+ transport=httpx.AsyncHTTPTransport(retries=20), timeout=600
+ )
+ self.limiter = AsyncLimiter(self.max_requests_per_minute // 2, time_period=60)
+
+ async def call_api(self, url: str):
+ async with self.limiter:
+ r = await self.client.get(url)
+ if not r.is_success:
+ logger.error("Error in API call: %s", r.text)
+ return r
+
+
+ async def get_tile_image_points(self, tile):
+ url = self.tile_info_url.format(
+ x=tile.x,
+ y=tile.y,
+ z=tile.z,
+ token=self.token
+ )
+ try :
+ r = await self.call_api(url)
+ if r.is_success:
+ geo_d = vt_bytes_to_geojson(
+ b_content=r._content,
+ x=tile.x,
+ y=tile.y,
+ z=tile.z,
+ layer="image",
+ )
+ d = geo_d["features"]
+ return tile, d
+ except Exception as e:
+ logger.error(f"{type(e).__name__}: {e}")
+ return tile, None
+
+ async def get_tiles_image_points(self, tiles, retries=3):
+ tile_to_images = {}
+ tasks = [self.get_tile_image_points(t) for t in tiles]
+ for i in range(retries):
+ failed_tiles = list()
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
+ tile, image_ids = await task
+ if image_ids is not None:
+ tile_to_images[f"z_{tile.z}_x{tile.x}_y{tile.y}"] = image_ids
+ else:
+ logger.error(f"Error when retrieving tile z_{tile.z}_x{tile.x}_y{tile.y}. Image_ids is None. Skipping.")
+ failed_tiles.append(tile)
+ if len(failed_tiles) == 0:
+ break
+ else:
+ if i == retries-1:
+ logger.error(f"Failed to retrieve {len(failed_tiles)} tiles in attempt {i}. Maxed out retries. Skipping those tiles.")
+ else:
+ logger.error(f"Failed to retrieve {len(failed_tiles)} tiles in attempt {i}. Trying again..")
+ tasks = [self.get_tile_image_points(t) for t in failed_tiles]
+ return tile_to_images
+
+
+ async def get_image_info(self, image_id: int):
+ url = self.image_info_url.format(
+ image_id=image_id,
+ token=self.token,
+ fields=",".join(self.image_fields),
+ )
+ r = await self.call_api(url)
+ if r.is_success:
+ return json.loads(r.text)
+
+ async def get_sequence_info(self, seq_id: str):
+ url = self.seq_info_url.format(seq_id=seq_id, token=self.token)
+ r = await self.call_api(url)
+ if r.is_success:
+ return json.loads(r.text)
+
+ async def download_image_pixels(self, url: str, path: Path):
+ r = await self.call_api(url)
+ if r.is_success:
+ with open(path, "wb") as fid:
+ fid.write(r.content)
+ return r.is_success
+
+ async def get_image_info_cached(self, image_id: int, path: Path):
+ if path.exists():
+ info = json.loads(path.read_text())
+ else:
+ info = await self.get_image_info(image_id)
+ path.write_text(json.dumps(info))
+ return info
+
+ async def download_image_pixels_cached(self, url: str, path: Path):
+ if path.exists():
+ return True
+ else:
+ return await self.download_image_pixels(url, path)
+
+
+async def fetch_images_in_sequence(i, downloader):
+ async with semaphore:
+ info = await downloader.get_sequence_info(i)
+ image_ids = [int(d["id"]) for d in info["data"]]
+ return i, image_ids
+
+
+async def fetch_images_in_sequences(sequence_ids, downloader):
+ seq_to_images_ids = {}
+ tasks = [fetch_images_in_sequence(i, downloader) for i in sequence_ids]
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
+ i, image_ids = await task
+ seq_to_images_ids[i] = image_ids
+ return seq_to_images_ids
+
+
+async def fetch_image_info(i, downloader, dir_):
+ async with semaphore:
+ path = dir_ / info_filename.format(image_id=i)
+ # info = await downloader.get_image_info_cached(i, path)
+ info = await downloader.get_image_info(i) # FIXME: temporarily disable caching, takes too long to reads many (>1mil) files
+ return i, info
+
+
+async def fetch_image_infos(image_ids, downloader, dir_):
+ infos = {}
+ num_fail = 0
+ tasks = [fetch_image_info(i, downloader, dir_) for i in image_ids]
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
+ i, info = await task
+ if info is None:
+ num_fail += 1
+ else:
+ infos[i] = info
+ return infos, num_fail
+
+
+async def fetch_image_pixels(i, url, downloader, dir_, overwrite=False):
+ async with semaphore:
+ path = dir_ / image_filename.format(image_id=i)
+ if overwrite:
+ path.unlink(missing_ok=True)
+ success = await downloader.download_image_pixels_cached(url, path)
+ return i, success
+
+
+async def fetch_images_pixels(image_urls, downloader, dir_):
+ num_fail = 0
+ tasks = [fetch_image_pixels(*id_url, downloader, dir_) for id_url in image_urls]
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
+ i, success = await task
+ num_fail += not success
+ return num_fail
+
+
+def opensfm_camera_from_info(info: dict) -> Camera:
+ cam_type = info["camera_type"]
+ if cam_type == "perspective":
+ camera = Camera.create_perspective(*info["camera_parameters"])
+ elif cam_type == "fisheye":
+ camera = Camera.create_fisheye(*info["camera_parameters"])
+ elif Camera.is_panorama(cam_type):
+ camera = Camera.create_spherical()
+ else:
+ raise ValueError(cam_type)
+ camera.width = info["width"]
+ camera.height = info["height"]
+ camera.id = info["id"]
+ return camera
+
+
+def opensfm_shot_from_info(info: dict, projection: Projection) -> Shot:
+ latlong = info["computed_geometry.coordinates"][::-1]
+ alt = info["computed_altitude"]
+ xyz = projection.project(np.array([*latlong, alt]), return_z=True)
+ c_rotvec_w = np.array(info["computed_rotation"])
+ pose = Pose()
+ pose.set_from_cam_to_world(-c_rotvec_w, xyz)
+ camera = opensfm_camera_from_info(info)
+ return latlong, Shot(info["id"], camera, pose)
+
+
+def get_city_boundary(city, state=None, country=None, fetch_shape=False):
+ # Use Nominatim API to get the boundary of the city
+ base_url = "https://nominatim.openstreetmap.org/search"
+ params = {
+ 'city': city,
+ 'state': state,
+ 'country': country,
+ 'format': 'json',
+ 'limit': 1,
+ 'polygon_geojson': 1 if fetch_shape else 0
+ }
+
+ # Without a user-agent we may get blocked. This is an arbitrary user-agent and can be changed
+ # Rotating between user-agents may circumvent blocks but may not be fair
+ headers = {
+ 'User-Agent': f'mapperceptionnet_{city}_{state}'
+ }
+ response = requests.get(base_url, params=params, headers=headers)
+
+ if response.status_code != 200:
+ logger.error(f"Nominatim error when fetching boundary data for {city}, {state}.\n"
+ f"Status code: {response.status_code}. Content: {response.content}")
+ return None
+
+ data = response.json()
+
+ if data is None:
+ logger.warn(f"No data returned by Nominatim for {city}, {state}")
+ return None
+
+ # Extract bbox data from the API response
+ bbox_data = data[0]['boundingbox']
+ bbox = {
+ 'west': float(bbox_data[2]),
+ 'south': float(bbox_data[0]),
+ 'east': float(bbox_data[3]),
+ 'north': float(bbox_data[1])
+ }
+
+ if fetch_shape:
+ # Extract GeoJSON boundary data from the API response
+ boundary_geojson = data[0]['geojson']
+ boundary_geojson = {
+ "type": "FeatureCollection",
+ "features": [
+ {"type": "Feature",
+ "properties": {},
+ "geometry": boundary_geojson}]
+ }
+ return bbox, boundary_geojson
+ else:
+ return bbox
+
+
+def get_tiles_from_boundary(boundary_info, zoom=14):
+ if boundary_info["bound_type"] == "auto_shape":
+ # TODO: Instead of tiles from the big bbox, return tiles that hug the shape
+ geojson_shape = boundary_info["shape"]
+
+ # FIXME What to do when boundary is defined by multiple polygons!!
+ # Visualization tool https://geojson.tools/
+ coords = geojson_shape["features"][0]["geometry"]["coordinates"]
+ try:
+ polygon = geojson.Polygon(coords)
+ coordinates = turfpy.measurement.bbox(polygon)
+ except:
+ logger.warn(f"Boundary is defined by {len(coords)} polygons. Choosing first polygon blindly")
+ polygon = geojson.Polygon(coords[0])
+ coordinates = turfpy.measurement.bbox(polygon)
+
+ coordinates = dict(zip(["west", "south", "east", "north"], coordinates))
+ else:
+ coordinates = boundary_info["bbox"]
+
+ tiles = list(
+ mercantile.tiles(
+ **coordinates,
+ zooms=zoom,
+ )
+ )
+
+ return tiles
\ No newline at end of file
diff --git a/mia/fpv/filter_pipelines/mia.yaml b/mia/fpv/filter_pipelines/mia.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0fb0f3f4f407180f23e76fa78e5962a9f13448d8
--- /dev/null
+++ b/mia/fpv/filter_pipelines/mia.yaml
@@ -0,0 +1,54 @@
+filter_pipeline:
+ name: "mia"
+ filters:
+ - value_missing_filter:
+ name: "required_fields_filter"
+ keys:
+ - "thumb_2048_url"
+ - "computed_geometry.coordinates"
+ - "sfm_cluster.id"
+ - date_filter:
+ from_year: 2017
+ - value_in_list_filter:
+ name: "camera_model_filter"
+ key: "model"
+ exclude: False
+ lst:
+ - "hdr-as200v"
+ - "hdr-as300"
+ - "fdr-x3000"
+ - "fdr-x1000v"
+ - "gopromax"
+ - "goprofusionfs1.04.01.80.00"
+ - "goprofusion"
+ - "goprofusionfs1.04.01.70.00"
+ - "iphone11"
+ - "iphone11pro"
+ - "iphone11promax"
+ - "iphone12"
+ - "iphone12pro"
+ - "iphone12promax"
+ - "iphone13"
+ - "iphone13pro"
+ - "iphone13promax"
+ - "sm-g930v"
+ - "sm-g970u"
+ - "lm-v405"
+ - value_in_list_filter:
+ name: "camera_type_filter"
+ key: "camera_type"
+ exclude: False
+ lst:
+ - "perspective"
+ - "fisheye"
+ - angle_discrip_filter:
+ thresh: 20
+ less_than: True
+ - loc_discrip_filter:
+ thresh: 3
+ less_than: True
+ - value_range_filter:
+ name: "exif_filter"
+ key: "exif_orientation"
+ from_v: 1
+ to_v: 1
\ No newline at end of file
diff --git a/mia/fpv/filter_pipelines/mia_rural.yaml b/mia/fpv/filter_pipelines/mia_rural.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fea3fb201d3a0bed447c10dcabb2d3e0f4a39cce
--- /dev/null
+++ b/mia/fpv/filter_pipelines/mia_rural.yaml
@@ -0,0 +1,29 @@
+filter_pipeline:
+ name: "mia"
+ filters:
+ - value_missing_filter:
+ name: "required_fields_filter"
+ keys:
+ - "thumb_2048_url"
+ - "computed_geometry.coordinates"
+ - "sfm_cluster.id"
+ - date_filter:
+ from_year: 2017
+ - value_in_list_filter:
+ name: "camera_type_filter"
+ key: "camera_type"
+ exclude: False
+ lst:
+ - "perspective"
+ - "fisheye"
+ - angle_discrip_filter:
+ thresh: 20
+ less_than: True
+ - loc_discrip_filter:
+ thresh: 3
+ less_than: True
+ - value_range_filter:
+ name: "exif_filter"
+ key: "exif_orientation"
+ from_v: 1
+ to_v: 1
\ No newline at end of file
diff --git a/mia/fpv/filters.py b/mia/fpv/filters.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda806e8ad81865817054984c46c96f70d3dfe97
--- /dev/null
+++ b/mia/fpv/filters.py
@@ -0,0 +1,216 @@
+"""
+Contains the filters used to filter out images from the Mapillary API.
+"""
+
+import inspect
+import yaml
+from datetime import datetime
+from functools import partial
+
+import numpy as np
+import pandas as pd
+import shapely
+import shapely.geometry
+from shapely.prepared import prep
+from shapely import contains_xy
+
+from .. import logger
+
+def in_shape_filter(df: pd.DataFrame, geojson_shape):
+ polygon = shapely.geometry.shape(geojson_shape["features"][0]["geometry"])
+ mask = contains_xy(polygon, x=df["geometry.long"], y=df["geometry.lat"])
+ return mask
+
+def value_range_filter(df: pd.DataFrame, key, from_v=None, to_v=None):
+ c = df[key]
+ if from_v is not None and to_v is not None:
+ if from_v == to_v:
+ return c == from_v
+ else:
+ return np.logical_and(c >= from_v, c <= to_v)
+ elif from_v is not None:
+ return c >= from_v
+ elif to_v is not None:
+ return c <= to_v
+ else:
+ raise Exception("from_v and to_v cannot both be None")
+
+def value_in_list_filter(df: pd.DataFrame, key, lst, exclude=False):
+ mask = df[key].isin(lst)
+ if exclude:
+ mask = ~mask
+ return mask
+
+
+def value_missing_filter(df: pd.DataFrame, keys):
+ return np.all(df[keys].notna(), axis=1)
+
+
+def date_filter(df: pd.DataFrame, from_year=None, to_year=None):
+ """
+ Args:
+ before_year: integer representing the year
+ after_year: integer representing the year
+ """
+ if from_year is not None:
+ from_year = int(datetime(from_year, 1, 1).timestamp())*1e3
+ if to_year is not None:
+ to_year = int(datetime(to_year, 1, 1).timestamp())*1e3
+ return value_range_filter(df, "captured_at", from_year, to_year)
+
+def quality_score_filter(df: pd.DataFrame, from_score=None, to_score=None):
+ return value_range_filter(df, "quality_score", from_v=from_score, to_v=to_score)
+
+def angle_dist(a1, a2):
+ a = a1-a2
+ return np.abs((a + 180) % 360 - 180)
+
+def angle_discrip_filter(df: pd.DataFrame, thresh, less_than=True):
+ """
+ Args:
+ thresh: Threshold in degrees
+ """
+ a1 = df["computed_compass_angle"]
+ a2 = df["compass_angle"]
+
+ diff = angle_dist(a1, a2)
+
+ if less_than:
+ return diff < thresh
+ else:
+ return diff > thresh
+
+def haversine_np(lon1, lat1, lon2, lat2):
+ """
+ Calculate the great circle distance between two points
+ on the earth (specified in decimal degrees)
+
+ All args must be of equal length.
+
+ """
+ lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
+
+ dlon = lon2 - lon1
+ dlat = lat2 - lat1
+
+ a = np.sin(dlat/2.0)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2.0)**2
+
+ c = 2 * np.arcsin(np.sqrt(a))
+ km = 6378.137 * c
+ return km*1e3
+
+def loc_discrip_filter(df: pd.DataFrame, thresh, less_than=True):
+ """
+ Args:
+ thresh: Threshold in meters
+ """
+ lat1 = df["computed_geometry.lat"]
+ lon1 = df["computed_geometry.long"]
+ lat2 = df["geometry.lat"]
+ lon2 = df["geometry.long"]
+ diff = haversine_np(lon1, lat1, lon2, lat2)
+ if less_than:
+ return diff < thresh
+ else:
+ return diff > thresh
+
+def sequence_sparsity_filter(df: pd.DataFrame, dist_thresh):
+ """
+ TODO
+ This filter filters out images that are too close to each other within a sequence
+ """
+ pass
+
+
+class Filter():
+ def __init__(self, filter_func, name=None, **kwargs):
+ self.filter_func = filter_func
+ self.name = name
+ self.kwargs = kwargs
+
+ def __call__(self, df: pd.DataFrame):
+ return self.filter_func(df, **self.kwargs)
+
+ def __str__(self) -> str:
+ if self.name is None:
+ tag = self.filter_func.__name__
+ else:
+ tag = f"{self.filter_func.__name__}:{self.name}"
+ return tag
+
+ def __repr__(self):
+ kwargs_fmt = ", ".join([f"{k}={v}" for k,v in self.kwargs.items()])
+ return f"{self.__str__()} | kwargs({kwargs_fmt})"
+
+
+class FilterPipeline():
+ def __init__(self, filters: list, sequential=True, name=None, verbose=True):
+ """
+ Args:
+ sequential: Whether to apply filters sequentially or compute the masks
+ for all of them then apply once at the end.
+ verbose: Whether to log the effect of each filter or not
+ """
+ self.filters = filters
+ self.sequential = sequential
+ self.name = name
+ self.verbose = verbose
+
+ def __call__(self, df: pd.DataFrame):
+ N = df.shape[0]
+ if not self.sequential:
+ running_mask = np.full(df.shape[0], True, dtype=bool)
+
+ for f in self.filters:
+ mask = f(df)
+ if self.verbose:
+ s = np.sum(mask)
+ logger.info(f"{f} keeps {s}/{mask.shape[0]} ({s/mask.shape[0]*100:.2f}%) of the images")
+
+ if self.sequential:
+ df = df[mask]
+ if df.shape[0] == 0:
+ logger.warn("No images left during filtering.. Stopping pipeline")
+ return df
+ else:
+ running_mask = np.logical_and(running_mask, mask)
+
+ if not self.sequential:
+ df = df[running_mask]
+
+ logger.info(f"Filter Pipeline {self.name} kept {df.shape[0]}/{N} ({df.shape[0]/N*100:.2f}%) of the images")
+ return df
+
+ def __str__(self):
+ return f"Pipeline {self.name}: " + "\n".join([str(x) for x in self.filters])
+
+ def __repr__(self):
+ return f"Pipeline {self.name}: " + "\n".join([repr(x) for x in self.filters])
+
+ @staticmethod
+ def load_from_yaml(file_path):
+ def is_primitive(x):
+ return isinstance(x, (float, int, bool, str))
+
+ with open(file_path, 'r') as stream:
+ pipeline_dict = yaml.safe_load(stream)["filter_pipeline"]
+
+ sig = inspect.signature(FilterPipeline.__init__)
+ init_args = dict()
+ for param in sig.parameters.values():
+ if param.name in pipeline_dict and is_primitive(pipeline_dict[param.name]):
+ init_args[param.name] = pipeline_dict[param.name]
+
+ filter_dicts = pipeline_dict["filters"]
+ filters = list()
+
+ for filter_dict in filter_dicts:
+ filter_func_name, kwargs = list(filter_dict.items())[0]
+ filter_func = globals()[filter_func_name]
+ filters.append(Filter(filter_func=filter_func, **kwargs))
+
+ pipeline = FilterPipeline(filters, **init_args)
+ return pipeline
+
+if __name__ == "__main__":
+ FilterPipeline.load_from_yaml("mia/fpv/filter_pipelines/mia.yaml")
\ No newline at end of file
diff --git a/mia/fpv/geo.py b/mia/fpv/geo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d458b6009ccd1a3ac973ff4e38f7465e5546895f
--- /dev/null
+++ b/mia/fpv/geo.py
@@ -0,0 +1,132 @@
+# Copied from OrienterNet
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Union
+
+import numpy as np
+import torch
+
+from .. import logger
+from .geo_opensfm import TopocentricConverter
+
+
+class BoundaryBox:
+ def __init__(self, min_: np.ndarray, max_: np.ndarray):
+ self.min_ = np.asarray(min_)
+ self.max_ = np.asarray(max_)
+ assert np.all(self.min_ <= self.max_)
+
+ @classmethod
+ def from_string(cls, string: str):
+ return cls(*np.split(np.array(string.split(","), float), 2))
+
+ @property
+ def left_top(self):
+ return np.stack([self.min_[..., 0], self.max_[..., 1]], -1)
+
+ @property
+ def right_bottom(self) -> (np.ndarray, np.ndarray):
+ return np.stack([self.max_[..., 0], self.min_[..., 1]], -1)
+
+ @property
+ def center(self) -> np.ndarray:
+ return (self.min_ + self.max_) / 2
+
+ @property
+ def size(self) -> np.ndarray:
+ return self.max_ - self.min_
+
+ def translate(self, t: float):
+ return self.__class__(self.min_ + t, self.max_ + t)
+
+ def contains(self, xy: Union[np.ndarray, "BoundaryBox"]):
+ if isinstance(xy, self.__class__):
+ return self.contains(xy.min_) and self.contains(xy.max_)
+ return np.all((xy >= self.min_) & (xy <= self.max_), -1)
+
+ def normalize(self, xy):
+ min_, max_ = self.min_, self.max_
+ if isinstance(xy, torch.Tensor):
+ min_ = torch.from_numpy(min_).to(xy)
+ max_ = torch.from_numpy(max_).to(xy)
+ return (xy - min_) / (max_ - min_)
+
+ def unnormalize(self, xy):
+ min_, max_ = self.min_, self.max_
+ if isinstance(xy, torch.Tensor):
+ min_ = torch.from_numpy(min_).to(xy)
+ max_ = torch.from_numpy(max_).to(xy)
+ return xy * (max_ - min_) + min_
+
+ def format(self) -> str:
+ return ",".join(np.r_[self.min_, self.max_].astype(str))
+
+ def __add__(self, x):
+ if isinstance(x, (int, float)):
+ return self.__class__(self.min_ - x, self.max_ + x)
+ else:
+ raise TypeError(f"Cannot add {self.__class__.__name__} to {type(x)}.")
+
+ def __and__(self, other):
+ return self.__class__(
+ np.maximum(self.min_, other.min_), np.minimum(self.max_, other.max_)
+ )
+
+ def __repr__(self):
+ return self.format()
+
+
+class Projection:
+ def __init__(self, lat, lon, alt=0, max_extent=25e3):
+ # The approximation error is |L - radius * tan(L / radius)|
+ # and is around 13cm for L=25km.
+ self.latlonalt = (lat, lon, alt)
+ self.converter = TopocentricConverter(lat, lon, alt)
+ min_ = self.converter.to_lla(*(-max_extent,) * 2, 0)[:2]
+ max_ = self.converter.to_lla(*(max_extent,) * 2, 0)[:2]
+ self.bounds = BoundaryBox(min_, max_)
+
+ @classmethod
+ def from_points(cls, all_latlon):
+ assert all_latlon.shape[-1] == 2
+ all_latlon = all_latlon.reshape(-1, 2)
+ latlon_mid = (all_latlon.min(0) + all_latlon.max(0)) / 2
+ return cls(*latlon_mid)
+
+ def check_bbox(self, bbox: BoundaryBox):
+ if self.bounds is not None and not self.bounds.contains(bbox):
+ raise ValueError(
+ f"Bbox {bbox.format()} is not contained in "
+ f"projection with bounds {self.bounds.format()}."
+ )
+
+ def project(self, geo, return_z=False):
+ if isinstance(geo, BoundaryBox):
+ return BoundaryBox(*self.project(np.stack([geo.min_, geo.max_])))
+ geo = np.asarray(geo)
+ assert geo.shape[-1] in (2, 3)
+ if self.bounds is not None:
+ if not np.all(self.bounds.contains(geo[..., :2])):
+ raise ValueError(
+ f"Points {geo} are out of the valid bounds "
+ f"{self.bounds.format()}."
+ )
+ lat, lon = geo[..., 0], geo[..., 1]
+ if geo.shape[-1] == 3:
+ alt = geo[..., -1]
+ else:
+ alt = np.zeros_like(lat)
+ x, y, z = self.converter.to_topocentric(lat, lon, alt)
+ return np.stack([x, y] + ([z] if return_z else []), -1)
+
+ def unproject(self, xy, return_z=False):
+ if isinstance(xy, BoundaryBox):
+ return BoundaryBox(*self.unproject(np.stack([xy.min_, xy.max_])))
+ xy = np.asarray(xy)
+ x, y = xy[..., 0], xy[..., 1]
+ if xy.shape[-1] == 3:
+ z = xy[..., -1]
+ else:
+ z = np.zeros_like(x)
+ lat, lon, alt = self.converter.to_lla(x, y, z)
+ return np.stack([lat, lon] + ([alt] if return_z else []), -1)
\ No newline at end of file
diff --git a/mia/fpv/geo_opensfm.py b/mia/fpv/geo_opensfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..166d501f17b8fb4d8ca535406bd303111be8afe3
--- /dev/null
+++ b/mia/fpv/geo_opensfm.py
@@ -0,0 +1,182 @@
+# Copied from OrienterNet
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+from numpy import ndarray
+from typing import Tuple
+
+WGS84_a = 6378137.0
+WGS84_b = 6356752.314245
+
+
+def ecef_from_lla(lat, lon, alt: float) -> Tuple[float, ...]:
+ """
+ Compute ECEF XYZ from latitude, longitude and altitude.
+
+ All using the WGS84 model.
+ Altitude is the distance to the WGS84 ellipsoid.
+ Check results here http://www.oc.nps.edu/oc2902w/coord/llhxyz.htm
+
+ >>> lat, lon, alt = 10, 20, 30
+ >>> x, y, z = ecef_from_lla(lat, lon, alt)
+ >>> np.allclose(lla_from_ecef(x,y,z), [lat, lon, alt])
+ True
+ """
+ a2 = WGS84_a**2
+ b2 = WGS84_b**2
+ lat = np.radians(lat)
+ lon = np.radians(lon)
+ L = 1.0 / np.sqrt(a2 * np.cos(lat) ** 2 + b2 * np.sin(lat) ** 2)
+ x = (a2 * L + alt) * np.cos(lat) * np.cos(lon)
+ y = (a2 * L + alt) * np.cos(lat) * np.sin(lon)
+ z = (b2 * L + alt) * np.sin(lat)
+ return x, y, z
+
+
+def lla_from_ecef(x, y, z):
+ """
+ Compute latitude, longitude and altitude from ECEF XYZ.
+
+ All using the WGS84 model.
+ Altitude is the distance to the WGS84 ellipsoid.
+ """
+ a = WGS84_a
+ b = WGS84_b
+ ea = np.sqrt((a**2 - b**2) / a**2)
+ eb = np.sqrt((a**2 - b**2) / b**2)
+ p = np.sqrt(x**2 + y**2)
+ theta = np.arctan2(z * a, p * b)
+ lon = np.arctan2(y, x)
+ lat = np.arctan2(
+ z + eb**2 * b * np.sin(theta) ** 3, p - ea**2 * a * np.cos(theta) ** 3
+ )
+ N = a / np.sqrt(1 - ea**2 * np.sin(lat) ** 2)
+ alt = p / np.cos(lat) - N
+ return np.degrees(lat), np.degrees(lon), alt
+
+
+def ecef_from_topocentric_transform(lat, lon, alt: float) -> ndarray:
+ """
+ Transformation from a topocentric frame at reference position to ECEF.
+
+ The topocentric reference frame is a metric one with the origin
+ at the given (lat, lon, alt) position, with the X axis heading east,
+ the Y axis heading north and the Z axis vertical to the ellipsoid.
+ >>> a = ecef_from_topocentric_transform(30, 20, 10)
+ >>> b = ecef_from_topocentric_transform_finite_diff(30, 20, 10)
+ >>> np.allclose(a, b)
+ True
+ """
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ sa = np.sin(np.radians(lat))
+ ca = np.cos(np.radians(lat))
+ so = np.sin(np.radians(lon))
+ co = np.cos(np.radians(lon))
+ return np.array(
+ [
+ [-so, -sa * co, ca * co, x],
+ [co, -sa * so, ca * so, y],
+ [0, ca, sa, z],
+ [0, 0, 0, 1],
+ ]
+ )
+
+
+def ecef_from_topocentric_transform_finite_diff(lat, lon, alt: float) -> ndarray:
+ """
+ Transformation from a topocentric frame at reference position to ECEF.
+
+ The topocentric reference frame is a metric one with the origin
+ at the given (lat, lon, alt) position, with the X axis heading east,
+ the Y axis heading north and the Z axis vertical to the ellipsoid.
+ """
+ eps = 1e-2
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ v1 = (
+ (
+ np.array(ecef_from_lla(lat, lon + eps, alt))
+ - np.array(ecef_from_lla(lat, lon - eps, alt))
+ )
+ / 2
+ / eps
+ )
+ v2 = (
+ (
+ np.array(ecef_from_lla(lat + eps, lon, alt))
+ - np.array(ecef_from_lla(lat - eps, lon, alt))
+ )
+ / 2
+ / eps
+ )
+ v3 = (
+ (
+ np.array(ecef_from_lla(lat, lon, alt + eps))
+ - np.array(ecef_from_lla(lat, lon, alt - eps))
+ )
+ / 2
+ / eps
+ )
+ v1 /= np.linalg.norm(v1)
+ v2 /= np.linalg.norm(v2)
+ v3 /= np.linalg.norm(v3)
+ return np.array(
+ [
+ [v1[0], v2[0], v3[0], x],
+ [v1[1], v2[1], v3[1], y],
+ [v1[2], v2[2], v3[2], z],
+ [0, 0, 0, 1],
+ ]
+ )
+
+
+def topocentric_from_lla(lat, lon, alt: float, reflat, reflon, refalt: float):
+ """
+ Transform from lat, lon, alt to topocentric XYZ.
+
+ >>> lat, lon, alt = -10, 20, 100
+ >>> np.allclose(topocentric_from_lla(lat, lon, alt, lat, lon, alt),
+ ... [0,0,0])
+ True
+ >>> x, y, z = topocentric_from_lla(lat, lon, alt, 0, 0, 0)
+ >>> np.allclose(lla_from_topocentric(x, y, z, 0, 0, 0),
+ ... [lat, lon, alt])
+ True
+ """
+ T = np.linalg.inv(ecef_from_topocentric_transform(reflat, reflon, refalt))
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ tx = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3]
+ ty = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3]
+ tz = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3]
+ return tx, ty, tz
+
+
+def lla_from_topocentric(x, y, z, reflat, reflon, refalt: float):
+ """
+ Transform from topocentric XYZ to lat, lon, alt.
+ """
+ T = ecef_from_topocentric_transform(reflat, reflon, refalt)
+ ex = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3]
+ ey = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3]
+ ez = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3]
+ return lla_from_ecef(ex, ey, ez)
+
+
+class TopocentricConverter(object):
+ """Convert to and from a topocentric reference frame."""
+
+ def __init__(self, reflat, reflon, refalt):
+ """Init the converter given the reference origin."""
+ self.lat = reflat
+ self.lon = reflon
+ self.alt = refalt
+
+ def to_topocentric(self, lat, lon, alt):
+ """Convert lat, lon, alt to topocentric x, y, z."""
+ return topocentric_from_lla(lat, lon, alt, self.lat, self.lon, self.alt)
+
+ def to_lla(self, x, y, z):
+ """Convert topocentric x, y, z to lat, lon, alt."""
+ return lla_from_topocentric(x, y, z, self.lat, self.lon, self.alt)
+
+ def __eq__(self, o):
+ return np.allclose([self.lat, self.lon, self.alt], (o.lat, o.lon, o.alt))
\ No newline at end of file
diff --git a/mia/fpv/get_fpv.py b/mia/fpv/get_fpv.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d7bc407597a1c8645552b1543f95535cab2f966
--- /dev/null
+++ b/mia/fpv/get_fpv.py
@@ -0,0 +1,397 @@
+"""python3.9 -m mia.fpv.get_fpv --cfg mia/conf/example.yaml"""
+
+import argparse
+import itertools
+import traceback
+from functools import partial
+from typing import Dict
+from pathlib import Path
+import tracemalloc
+import copy
+import json
+
+import numpy as np
+import asyncio
+from tqdm import tqdm
+from omegaconf import OmegaConf
+import pandas as pd
+
+from .. import logger
+from .geo import Projection
+
+from .download import (
+ MapillaryDownloader,
+ fetch_image_infos,
+ fetch_images_pixels,
+ get_city_boundary,
+ get_tiles_from_boundary,
+)
+from .prepare import process_sequence, default_cfg
+from .filters import in_shape_filter, FilterPipeline
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ return json.JSONEncoder.default(self, obj)
+
+def write_json(path, data):
+ with open(path, "w") as f:
+ json.dump(data, f, cls=JSONEncoder)
+
+def get_token(token: str) -> str:
+ if Path(token).is_file():
+ logger.info(f"Reading token from file {token}")
+ with open(token, 'r') as file:
+ token = file.read().strip()
+
+ if not token.startswith("MLY"):
+ logger.fatal(f"The token '{token}' is invalid")
+ exit(1)
+ else:
+ logger.info(f"Using token {token}")
+ return token
+
+def fetch_city_boundaries(cities: list):
+ """
+ Args:
+ cities: List of dictionaries describing the city/region to fetch in the fpv.yaml format.
+ """
+ data = []
+ pbar = tqdm(cities)
+ for loc_info in pbar:
+ loc_fmt = loc_info["name"]
+
+ if "state" in loc_info:
+ loc_fmt = f"{loc_fmt}, {loc_info['state']}"
+ else:
+ loc_info["state"] = ""
+
+ if "country" in loc_info:
+ loc_fmt = f"{loc_fmt}, {loc_info['country']}"
+ else:
+ loc_info["country"] = ""
+
+ pbar.set_description(f"Getting boundary for {loc_fmt}")
+ entry = copy.copy(dict(loc_info))
+
+ get_city_boundary_ = partial(get_city_boundary, loc_info["name"], loc_info["state"], loc_info["country"])
+ if "bound_type" not in loc_info:
+ assert "sequence_ids" in loc_info
+ raise NotImplementedError()
+ elif loc_info["bound_type"] == "custom_bbox":
+ assert "custom_bbox" in loc_info
+ entry["bbox"] = dict(zip(["west", "south", "east", "north"],
+ [float(x) for x in loc_info["custom_bbox"].split(",")]))
+ elif loc_info["bound_type"] == "auto_shape":
+ entry["bbox"], entry["shape"] = get_city_boundary_(fetch_shape=True)
+ elif loc_info["bound_type"] == "auto_bbox":
+ entry["bbox"] = get_city_boundary_(fetch_shape=False)
+ elif loc_info["bound_type"] == "custom_size":
+ assert "custom_size" in loc_info
+ custom_size = loc_info["custom_size"]
+ bbox = get_city_boundary_(fetch_shape=False)
+ # Calculation below is obviously not very accurate.
+ # Good enough for small bounding boxes
+ bbox_center = [(bbox['west'] + bbox['east'])/2, (bbox['south'] + bbox['north'])/2]
+ bbox['west'] = bbox_center[0] - custom_size / (111.32*np.cos(np.deg2rad(bbox_center[1])))
+ bbox['east'] = bbox_center[0] + custom_size / (111.32*np.cos(np.deg2rad(bbox_center[1])))
+ bbox['south'] = bbox_center[1] - custom_size / 111.32
+ bbox['north'] = bbox_center[1] + custom_size / 111.32
+ entry["bbox"] = bbox
+ entry["custom_size"] = custom_size
+ else:
+ raise Exception(f"Unsupported bound_type type '{loc_info['bound_type']}'")
+
+ data.append(entry)
+ return data
+
+def geojson_feature_list_to_pandas(feature_list, split_coords=True):
+ t = pd.json_normalize(feature_list)
+ cols_to_drop = ["type", "geometry.type", "properties.organization_id", "computed_geometry.type"]
+ if split_coords:
+ t[['geometry.long','geometry.lat']] = pd.DataFrame(t["geometry.coordinates"].tolist(), index=t.index)
+ # Computed geometry maybe nan if its not available so we check if the value could be a nan (a float type)
+ if "computed_geometry.coordinates" in t.columns:
+ t["computed_geometry.long"] = t["computed_geometry.coordinates"].map(lambda x: (x if isinstance(x, float) else x[0]) )
+ t["computed_geometry.lat"] = t["computed_geometry.coordinates"].map(lambda x: (x if isinstance(x, float) else x[1]) )
+
+ t.drop(columns=cols_to_drop, inplace=True, errors="ignore")
+ t.columns = t.columns.str.removeprefix('properties.')
+ t["id"] = t["id"].astype(str)
+ return t
+
+def parse_image_points_json_data(rd: dict, combine=True) -> pd.DataFrame:
+ """
+ Parse the json in to a pandas dataframe
+ """
+ df_dict = dict()
+ for tile, feature_list in tqdm(rd.items(), total=len(rd)):
+ if len(feature_list) == 0:
+ continue
+ df_dict[tile] = geojson_feature_list_to_pandas(feature_list)
+
+ if combine:
+ logger.info(f"Joining all dataframes into one.")
+ return pd.concat(df_dict.values())
+ else:
+ return df_dict
+
+def log_memory_usage():
+ current, peak = tracemalloc.get_traced_memory()
+ current_gb = current / 10**9
+ peak_gb = peak / 10**9
+ logger.info(f"Current memory: {current_gb:.3f} GB; Peak was {peak_gb:.3f} GB")
+
+def main(args, cfgs):
+ pipeline = FilterPipeline.load_from_yaml(cfgs.fpv_options.filter_pipeline_cfg)
+
+ # setup the mapillary downloader
+ tracemalloc.start()
+ token = get_token(args.token)
+ downloader = MapillaryDownloader(token)
+ loop = asyncio.get_event_loop()
+
+ # setup file structure
+ dataset_dir = Path(cfgs.dataset_dir)
+ dataset_dir.mkdir(exist_ok=True, parents=True)
+
+ # Fetch the bounds for the cities
+ logger.info(f"Auto fetching boundaries for cities if needed.")
+ cities_bounds_info = fetch_city_boundaries(cfgs.cities)
+
+ log_memory_usage()
+
+ # loop through the cities and collect the mapillary data (images, metadata, etc.)
+ for city_boundary_info in cities_bounds_info:
+ # Clear out dataframes since we may use None checks to see if we need
+ # to load the dataframe for a particular stage
+ df = None
+ df_meta = None
+ df_meta_filtered = None
+ df_meta_filtered_processed = None
+
+ logger.info(f"Processing {city_boundary_info['name']}")
+ # setup the directories
+ location_name = city_boundary_info['name'].lower().replace(" ", "_")
+ location_dir = dataset_dir / location_name
+ infos_dir = location_dir / "image_infos_chunked"
+ raw_image_dir = location_dir / "images_raw"
+ out_image_dir = location_dir / "images"
+ for d in (infos_dir, raw_image_dir, out_image_dir, location_dir):
+ if not d.exists():
+ logger.info(f"{d} does not exist. Creating directory {d}")
+ d.mkdir(parents=True, exist_ok=True)
+ write_json(location_dir / "boundary_info.json", city_boundary_info)
+
+ # Stage 1: collect the id of the images in the specified bounding box
+ if cfgs.fpv_options.stages.get_image_points_from_tiles:
+ logger.info(f"[{location_name}] Stage 1 (Downloading image IDs) ------------------")
+ tiles = get_tiles_from_boundary(city_boundary_info)
+ logger.info(f"[{location_name}] Found {len(tiles)} zoom-14 tiles for this boundary. Starting image point download")
+ image_points_response = loop.run_until_complete(
+ downloader.get_tiles_image_points(tiles)
+ )
+ if image_points_response is None:
+ logger.warn(f"[{location_name}] No image points found in boundary. Skipping city")
+ continue
+ write_json(location_dir / 'images_points_dump.json', image_points_response)
+
+ # parse the data into a geopandas dataframe
+ logger.info(f"[{location_name}] Parsing image point json data into dataframe")
+ df = parse_image_points_json_data(image_points_response)
+
+ # Filter if needed
+ if city_boundary_info["bound_type"] == "auto_shape":
+ old_count = df.shape[0]
+ df = df[in_shape_filter(df, city_boundary_info["shape"])]
+ new_count = df.shape[0]
+ logger.info(f"[{location_name}] Keeping {new_count}/{old_count} ({new_count/old_count*100:.2f}%) "
+ "points that are within city boundaries")
+ df.to_parquet(location_dir / 'image_points.parquet')
+
+ # Stage 2: download the metadata
+ if cfgs.fpv_options.stages.get_metadata:
+ logger.info(f"[{location_name}] Stage 2 (Downloading Metadata) ------------------")
+ if df is None:
+ pq_name = 'image_points.parquet'
+ df = pd.read_parquet(location_dir / pq_name)
+ logger.info(f"[{location_name}] Loaded {df.shape[0]} image points from {pq_name}")
+ log_memory_usage()
+
+ # chunk settings
+ chunk_size = cfgs.fpv_options.metadata_download_chunk_size
+ num_split = int(np.ceil(df.shape[0] / chunk_size))
+ logger.info(f"[{location_name}] Splitting the {df.shape[0]} image points into {num_split} chunks of {chunk_size} image points each.")
+
+ # check if the metadata chunk has already been downloaded
+ num_downloaded_chunks = 0
+ num_of_chunks_in_dir = len(list(infos_dir.glob("image_metadata_chunk_*.parquet")))
+ df_meta_chunks = list()
+ df_meta = pd.DataFrame()
+ if infos_dir.exists() and num_of_chunks_in_dir > 0:
+ logger.info(f"[{location_name}] Found {len(list(infos_dir.glob('image_metadata_chunk_*.parquet')))} existing metadata chunks.")
+ downloaded_ids = []
+ num_downloaded_data_pts = 0
+ pbar = tqdm(infos_dir.glob("image_metadata_chunk_*.parquet"), total=num_of_chunks_in_dir)
+ for chunk_fp in pbar:
+ pbar.set_description(f"Loading {chunk_fp}")
+ chunk_df = pd.read_parquet(chunk_fp)
+ df_meta_chunks.append(chunk_df)
+ num_downloaded_chunks += 1
+ num_downloaded_data_pts += len(chunk_df)
+ log_memory_usage()
+
+ num_pts_left = df.shape[0] - num_downloaded_data_pts
+
+ df_meta = pd.concat(df_meta_chunks)
+ df_meta_chunks.clear()
+ df = df[~df["id"].isin(df_meta["id"])]
+
+ # some quick checks to make sure the data is consistent
+ left_num_split = int(np.ceil(df.shape[0] / chunk_size))
+ # if num_downloaded_chunks != (num_split - left_num_split):
+ # raise ValueError(f"Number of downloaded chunks {num_downloaded_chunks} does not match the number of chunks {num_split - left_num_split}")
+ if num_pts_left != len(df):
+ raise ValueError(f"Number of points left {num_pts_left} does not match the number of points in the dataframe {len(df)}")
+
+ if num_pts_left > 0:
+ logger.info(f"Restarting metadata download with {num_pts_left} points, {left_num_split} chunks left to download.")
+
+ # download the metadata
+ num_split = int(np.ceil(df.shape[0] / chunk_size))
+ groups = df.groupby(np.arange(len(df.index)) // chunk_size)
+
+ for (frame_num, frame) in groups:
+ frame_num = frame_num + num_downloaded_chunks
+ logger.info(f"[{location_name}] Fetching metadata for {frame_num+1}/{num_split} chunk of {frame.shape[0]} image points.")
+ image_ids = frame["id"]
+ image_infos, num_fail = loop.run_until_complete(
+ fetch_image_infos(image_ids, downloader, infos_dir)
+ )
+ logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids))
+ if num_fail == len(image_ids):
+ logger.warn(f"[{location_name}] All images failed to be fetched. Skipping next steps")
+ continue
+ new_df_meta = geojson_feature_list_to_pandas(image_infos.values())
+ df_meta_chunks.append(new_df_meta)
+ new_df_meta.to_parquet(infos_dir / f'image_metadata_chunk_{frame_num}.parquet')
+ log_memory_usage()
+
+ # Combine all new chunks into one DF
+ df_meta = pd.concat([df_meta] + df_meta_chunks)
+ df_meta_chunks.clear()
+
+ # Some standardization of the data
+ df_meta["model"] = df_meta["model"].str.lower().str.replace(' ', '').str.replace('_', '')
+ df_meta["make"] = df_meta["make"].str.lower().str.replace(' ', '').str.replace('_', '')
+ df_meta.to_parquet(location_dir / 'image_metadata.parquet')
+
+ # Stage 3: run filter pipeline
+ if cfgs.fpv_options.stages.run_filter:
+ logger.info(f"[{location_name}] Stage 3 (Filtering) ------------------")
+
+ if df_meta is None:
+ pq_name = 'image_metadata.parquet'
+ df_meta = pd.read_parquet(location_dir / pq_name)
+ logger.info(f"[{location_name}] Loaded {df_meta.shape[0]} image metadata from {pq_name}")
+
+ df_meta_filtered = pipeline(df_meta)
+ df_meta_filtered.to_parquet(location_dir / f'image_metadata_filtered.parquet')
+ if df_meta_filtered.shape[0] == 0:
+ logger.warning(f"[{location_name}] No images to download. Moving on to next location.")
+ continue
+ else:
+ logger.info(f"[{location_name}] {df_meta_filtered.shape[0]} images to download.")
+
+ # Stage 4: Download filtered images
+ if cfgs.fpv_options.stages.download_images:
+ logger.info(f"[{location_name}] Stage 4 (Downloading Images) ------------------")
+ if df_meta_filtered is None:
+ pq_name = f'image_metadata_filtered.parquet'
+ df_meta_filtered = pd.read_parquet(location_dir / pq_name)
+ logger.info(f"[{location_name}] Loaded {df_meta_filtered.shape[0]} image metadata from {pq_name}")
+ log_memory_usage()
+ # filter out the images that have already been downloaded
+ downloaded_image_fps = list(raw_image_dir.glob("*.jpg"))
+ downloaded_image_ids = [fp.stem for fp in downloaded_image_fps]
+ df_to_download = df_meta_filtered[~df_meta_filtered["id"].isin(downloaded_image_ids)]
+ logger.info(f"[{location_name}] {len(downloaded_image_ids)} images already downloaded. {df_to_download.shape[0]} images left to download.")
+
+ # download the images
+ image_urls = list(df_to_download.set_index("id")["thumb_2048_url"].items())
+ if len(image_urls) > 0:
+ num_fail = loop.run_until_complete(
+ fetch_images_pixels(image_urls, downloader, raw_image_dir)
+ )
+ logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_urls))
+
+ # Stage 5: process the sequences
+ if cfgs.fpv_options.stages.to_process_sequence:
+ logger.info(f"[{location_name}] Stage 5 (Sequence Processing) ------------------")
+ if df_meta_filtered is None:
+ pq_name = f'image_metadata_filtered.parquet'
+ df_meta_filtered = pd.read_parquet(location_dir / pq_name)
+ logger.info(f"[{location_name}] Loaded {df_meta_filtered.shape[0]} image metadata from {pq_name}")
+ log_memory_usage()
+
+ # prepare the data for processing
+ seq_to_image_ids = df_meta_filtered.groupby('sequence')['id'].agg(list).to_dict()
+ lon_center = (city_boundary_info['bbox']['east'] + city_boundary_info['bbox']['west']) / 2
+ lat_center = (city_boundary_info['bbox']['north'] + city_boundary_info['bbox']['south']) / 2
+ projection = Projection(lat_center, lon_center, max_extent=50e3) # increase to 50km max extent for the projection, otherwise it will throw an error
+
+ df_meta_filtered.index = df_meta_filtered["id"]
+ image_infos = df_meta_filtered.to_dict(orient="index")
+ process_sequence_args = default_cfg
+
+ log_memory_usage()
+
+ # process the sequences
+ dump = {}
+ logger.info(f"[{location_name}] Processing downloaded sequences..")
+
+ processed_ids = list()
+
+ for seq_id, seq_image_ids in tqdm(seq_to_image_ids.items()):
+ try:
+ d, pi = process_sequence(
+ seq_image_ids,
+ image_infos,
+ projection,
+ process_sequence_args,
+ raw_image_dir,
+ out_image_dir,
+ )
+ if d is None or pi is None:
+ raise Exception("process_sequence returned None")
+ processed_ids.append(pi)
+ # TODO We shouldn't need dumps
+ dump.update(d)
+
+ except Exception as e:
+ logger.error(f"[{location_name}] Failed to process sequence {seq_id} skipping it. Error: {repr(e)}.")
+ logger.error(traceback.format_exc())
+
+ write_json(location_dir / "dump.json", dump)
+
+ # TODO: Ideally we want to move the keyframe selection filter to
+ # The filtering pipeline such that we do not download unnecessary
+ # Raw Images. But for now, we will filter the dataframe one more time after processing
+ processed_ids = list(itertools.chain.from_iterable(processed_ids))
+ df_meta_filtered_processed = df_meta_filtered[ df_meta_filtered["id"].isin(processed_ids)]
+ logger.info(f"[{location_name}] Final yield after processing is {df_meta_filtered_processed.shape[0]} images.")
+ df_meta_filtered_processed.to_parquet(location_dir / f'image_metadata_filtered_processed.parquet')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--cfg", type=str, default="mia/conf/example.yaml", help="Path to config yaml file.")
+ parser.add_argument("--token", type=str, default='mapillary_key', help="Either a token string or a path to a file containing the token.")
+ args = parser.parse_args()
+
+ cfgs = OmegaConf.load(args.cfg)
+
+ main(args, cfgs)
diff --git a/mia/fpv/prepare.py b/mia/fpv/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..aebd42c7b16be566223c9c626fdff5aad4354c74
--- /dev/null
+++ b/mia/fpv/prepare.py
@@ -0,0 +1,183 @@
+# Adapted from prepare.py
+
+import asyncio
+import argparse
+from collections import defaultdict
+import json
+import shutil
+from pathlib import Path
+from typing import List, Dict
+
+import numpy as np
+import cv2
+from tqdm import tqdm
+from tqdm.contrib.concurrent import thread_map
+from omegaconf import DictConfig, OmegaConf
+from opensfm.pygeometry import Camera
+from opensfm.pymap import Shot
+from opensfm.undistort import (
+ perspective_camera_from_fisheye,
+ perspective_camera_from_perspective,
+)
+
+from .. import logger
+# from ...osm.tiling import TileManager
+# from ...osm.viz import GeoPlotter
+from .geo import BoundaryBox, Projection
+from .utils import decompose_rotmat
+from .utils_sfm import (
+ keyframe_selection,
+ perspective_camera_from_pano,
+ scale_camera,
+ CameraUndistorter,
+ PanoramaUndistorter,
+ undistort_shot,
+)
+from .download import (
+ opensfm_shot_from_info,
+ image_filename,
+)
+
+
+default_cfg = OmegaConf.create(
+ {
+ "max_image_size": 512,
+ "do_legacy_pano_offset": True,
+ "min_dist_between_keyframes": 4,
+ "tiling": {
+ "tile_size": 128,
+ "margin": 128,
+ "ppm": 2,
+ },
+ }
+)
+
+
+def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float:
+ if do_legacy:
+ seed = int(image_info["sfm_cluster"]["id"])
+ else:
+ seed = image_info["sequence"].__hash__()
+ seed = seed % (2**32 - 1)
+ return np.random.RandomState(seed).uniform(-45, 45)
+
+
+def process_shot(
+ shot: Shot, info: dict, image_path: Path, output_dir: Path, cfg: DictConfig
+) -> List[Shot]:
+ if not image_path.exists():
+ logger.warn(f"Image {image_path} does not exist !")
+ return None
+
+ image_orig = cv2.imread(str(image_path))
+ max_size = cfg.max_image_size
+ pano_offset = None
+
+ camera = shot.camera
+ camera.width, camera.height = image_orig.shape[:2][::-1]
+ if camera.is_panorama(camera.projection_type):
+ camera_new = perspective_camera_from_pano(camera, max_size)
+ undistorter = PanoramaUndistorter(camera, camera_new)
+ pano_offset = get_pano_offset(info, cfg.do_legacy_pano_offset)
+ elif camera.projection_type in ["fisheye", "perspective"]:
+ if camera.projection_type == "fisheye":
+ camera_new = perspective_camera_from_fisheye(camera)
+ else:
+ camera_new = perspective_camera_from_perspective(camera)
+ camera_new = scale_camera(camera_new, max_size)
+ camera_new.id = camera.id + "_undistorted"
+ undistorter = CameraUndistorter(camera, camera_new)
+ else:
+ raise NotImplementedError(camera.projection_type)
+
+ shots_undist, images_undist = undistort_shot(
+ image_orig, shot, undistorter, pano_offset
+ )
+ for shot, image in zip(shots_undist, images_undist):
+ cv2.imwrite(str(output_dir / f"{shot.id}.jpg"), image)
+
+ return shots_undist
+
+
+def pack_shot_dict(shot: Shot, info: dict) -> dict:
+ latlong = info["computed_geometry.coordinates"][::-1]
+ latlong_gps = info["geometry.coordinates"][::-1]
+ w_p_c = shot.pose.get_origin()
+ w_r_c = shot.pose.get_R_cam_to_world()
+ rpy = decompose_rotmat(w_r_c)
+ return dict(
+ camera_id=shot.camera.id,
+ latlong=latlong,
+ t_c2w=w_p_c,
+ R_c2w=w_r_c,
+ roll_pitch_yaw=rpy,
+ capture_time=info["captured_at"],
+ gps_position=np.r_[latlong_gps, info["altitude"]],
+ compass_angle=info["compass_angle"],
+ chunk_id=int(info["sfm_cluster.id"]),
+ )
+
+
+def pack_camera_dict(camera: Camera) -> dict:
+ assert camera.projection_type == "perspective"
+ K = camera.get_K_in_pixel_coordinates(camera.width, camera.height)
+ return dict(
+ id=camera.id,
+ model="PINHOLE",
+ width=camera.width,
+ height=camera.height,
+ params=K[[0, 1, 0, 1], [0, 1, 2, 2]],
+ )
+
+
+def process_sequence(
+ image_ids: List[int],
+ image_infos: dict,
+ projection: Projection,
+ cfg: DictConfig,
+ raw_image_dir: Path,
+ out_image_dir: Path,
+):
+ shots = []
+ dump = {}
+ processed_ids = list()
+ if len(image_ids) == 0:
+ return dump, processed_ids
+
+ image_ids = sorted(image_ids, key=lambda i: image_infos[i]["captured_at"])
+ for i in image_ids:
+ _, shot = opensfm_shot_from_info(image_infos[i], projection)
+ shots.append(shot)
+ shot_idxs = keyframe_selection(shots, min_dist=cfg.min_dist_between_keyframes)
+ shots = [shots[i] for i in shot_idxs]
+
+ shots_out = thread_map(
+ lambda shot: process_shot(
+ shot,
+ image_infos[shot.id],
+ raw_image_dir / image_filename.format(image_id=shot.id),
+ out_image_dir,
+ cfg,
+ ),
+ shots,
+ disable=True,
+ )
+ shots_out = [s for s in shots_out if s is not None]
+ shots_out = [(i, s) for i, ss in enumerate(shots_out) for s in ss if ss is not None]
+
+ for index, shot in shots_out:
+ i, suffix = shot.id.rsplit("_", 1)
+ processed_ids.append(i)
+ info = image_infos[i]
+ seq_id = info["sequence"]
+ is_pano = not suffix.endswith("undistorted")
+ if is_pano:
+ seq_id += f"_{suffix}"
+ if seq_id not in dump:
+ dump[seq_id] = dict(views={}, cameras={})
+
+ view = pack_shot_dict(shot, info)
+ view["index"] = index
+ dump[seq_id]["views"][shot.id] = view
+ dump[seq_id]["cameras"][shot.camera.id] = pack_camera_dict(shot.camera)
+ return dump, processed_ids
\ No newline at end of file
diff --git a/mia/fpv/utils.py b/mia/fpv/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..440f0a51e6fba02e931718963487be8194ab569d
--- /dev/null
+++ b/mia/fpv/utils.py
@@ -0,0 +1,61 @@
+# Copied from OrienterNet
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+
+def crop_map(raster, xy, size, seed=None):
+ h, w = raster.shape[-2:]
+ state = np.random.RandomState(seed)
+ top = state.randint(0, h - size + 1)
+ left = state.randint(0, w - size + 1)
+ raster = raster[..., top : top + size, left : left + size]
+ xy -= np.array([left, top])
+ return raster, xy
+
+
+def random_rot90(raster, xy, heading, seed=None):
+ rot = np.random.RandomState(seed).randint(0, 4)
+ heading = (heading + rot * np.pi / 2) % (2 * np.pi)
+ h, w = raster.shape[-2:]
+ if rot == 0:
+ xy2 = xy
+ elif rot == 2:
+ xy2 = np.array([w, h]) - 1 - xy
+ elif rot == 1:
+ xy2 = np.array([xy[1], w - 1 - xy[0]])
+ elif rot == 3:
+ xy2 = np.array([h - 1 - xy[1], xy[0]])
+ else:
+ raise ValueError(rot)
+ raster = np.rot90(raster, rot, axes=(-2, -1))
+ return raster, xy2, heading
+
+
+def random_flip(image, raster, xy, heading, seed=None):
+ state = np.random.RandomState(seed)
+ if state.rand() > 0.5: # no flip
+ return image, raster, xy, heading
+ image = image[:, ::-1]
+ h, w = raster.shape[-2:]
+ if state.rand() > 0.5: # flip x
+ raster = raster[..., :, ::-1]
+ xy = np.array([w - 1 - xy[0], xy[1]])
+ heading = np.pi - heading
+ else: # flip y
+ raster = raster[..., ::-1, :]
+ xy = np.array([xy[0], h - 1 - xy[1]])
+ heading = -heading
+ heading = heading % (2 * np.pi)
+ return image, raster, xy, heading
+
+
+def decompose_rotmat(R_c2w):
+ R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
+ rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
+ roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
+ # rot_w2c_check = R_cv2xyz.inv() * Rotation.from_euler('YXZ', [roll, pitch, yaw], degrees=True)
+ # np.testing.assert_allclose(rot_w2c_check.as_matrix(), R_c2w.T, rtol=1e-6, atol=1e-6)
+ # R_plane2c = Rotation.from_euler("ZX", [roll, pitch], degrees=True).as_matrix()
+ return roll, pitch, yaw
\ No newline at end of file
diff --git a/mia/fpv/utils_sfm.py b/mia/fpv/utils_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0e0598314be63b651cd27e37b008e72c658c42
--- /dev/null
+++ b/mia/fpv/utils_sfm.py
@@ -0,0 +1,171 @@
+# Copied from OrienterNet
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+from opensfm import features
+from opensfm.pygeometry import Camera, compute_camera_mapping, Pose
+from opensfm.pymap import Shot
+from scipy.spatial.transform import Rotation
+
+
+def keyframe_selection(shots: List[Shot], min_dist: float = 4) -> List[int]:
+ camera_centers = np.stack([shot.pose.get_origin() for shot in shots], 0)
+ distances = np.linalg.norm(np.diff(camera_centers, axis=0), axis=1)
+ selected = [0]
+ cum = 0
+ for i in range(1, len(camera_centers)):
+ cum += distances[i - 1]
+ if cum >= min_dist:
+ selected.append(i)
+ cum = 0
+ return selected
+
+
+def perspective_camera_from_pano(camera: Camera, size: int) -> Camera:
+ camera_new = Camera.create_perspective(0.5, 0, 0)
+ camera_new.height = camera_new.width = size
+ camera_new.id = "perspective_from_pano"
+ return camera_new
+
+
+def scale_camera(camera: Camera, max_size: int) -> Camera:
+ height = camera.height
+ width = camera.width
+ factor = max_size / float(max(height, width))
+ if factor >= 1:
+ return camera
+ camera.width = int(round(width * factor))
+ camera.height = int(round(height * factor))
+ return camera
+
+
+class PanoramaUndistorter:
+ def __init__(self, camera_pano: Camera, camera_new: Camera):
+ w, h = camera_new.width, camera_new.height
+ self.shape = (h, w)
+
+ dst_y, dst_x = np.indices(self.shape).astype(np.float32)
+ dst_pixels_denormalized = np.column_stack([dst_x.ravel(), dst_y.ravel()])
+ dst_pixels = features.normalized_image_coordinates(
+ dst_pixels_denormalized, w, h
+ )
+ self.dst_bearings = camera_new.pixel_bearing_many(dst_pixels)
+
+ self.camera_pano = camera_pano
+ self.camera_perspective = camera_new
+
+ def __call__(
+ self, image: np.ndarray, panoshot: Shot, perspectiveshot: Shot
+ ) -> np.ndarray:
+ # Rotate to panorama reference frame
+ rotation = np.dot(
+ panoshot.pose.get_rotation_matrix(),
+ perspectiveshot.pose.get_rotation_matrix().T,
+ )
+ rotated_bearings = np.dot(self.dst_bearings, rotation.T)
+
+ # Project to panorama pixels
+ src_pixels = panoshot.camera.project_many(rotated_bearings)
+ src_pixels_denormalized = features.denormalized_image_coordinates(
+ src_pixels, image.shape[1], image.shape[0]
+ )
+ src_pixels_denormalized.shape = self.shape + (2,)
+
+ # Sample color
+ x = src_pixels_denormalized[..., 0].astype(np.float32)
+ y = src_pixels_denormalized[..., 1].astype(np.float32)
+ colors = cv2.remap(image, x, y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
+ return colors
+
+
+class CameraUndistorter:
+ def __init__(self, camera_distorted: Camera, camera_new: Camera):
+ self.maps = compute_camera_mapping(
+ camera_distorted,
+ camera_new,
+ camera_distorted.width,
+ camera_distorted.height,
+ )
+ self.camera_perspective = camera_new
+ self.camera_distorted = camera_distorted
+
+ def __call__(self, image: np.ndarray) -> np.ndarray:
+ assert image.shape[:2] == (
+ self.camera_distorted.height,
+ self.camera_distorted.width,
+ )
+ undistorted = cv2.remap(image, *self.maps, cv2.INTER_LINEAR)
+ resized = cv2.resize(
+ undistorted,
+ (self.camera_perspective.width, self.camera_perspective.height),
+ interpolation=cv2.INTER_AREA,
+ )
+ return resized
+
+
+def render_panorama(
+ shot: Shot,
+ pano: np.ndarray,
+ undistorter: PanoramaUndistorter,
+ offset: float = 0.0,
+) -> Tuple[List[Shot], List[np.ndarray]]:
+ yaws = [0, 90, 180, 270]
+ suffixes = ["front", "left", "back", "right"]
+ images = []
+ shots = []
+
+ # To reduce aliasing, since cv2.remap does not support area samplimg,
+ # we first resize with anti-aliasing.
+ h, w = undistorter.shape
+ h, w = (w * 2, w * 4) # assuming 90deg FOV
+ pano_resized = cv2.resize(pano, (w, h), interpolation=cv2.INTER_AREA)
+
+ for yaw, suffix in zip(yaws, suffixes):
+ R_pano2persp = Rotation.from_euler("Y", yaw + offset, degrees=True).as_matrix()
+ name = f"{shot.id}_{suffix}"
+ shot_new = Shot(
+ name,
+ undistorter.camera_perspective,
+ Pose.compose(Pose(R_pano2persp), shot.pose),
+ )
+ shot_new.metadata = shot.metadata
+ perspective = undistorter(pano_resized, shot, shot_new)
+ images.append(perspective)
+ shots.append(shot_new)
+ return shots, images
+
+
+def undistort_camera(
+ shot: Shot, image: np.ndarray, undistorter: CameraUndistorter
+) -> Tuple[Shot, np.ndarray]:
+ name = f"{shot.id}_undistorted"
+ shot_out = Shot(name, undistorter.camera_perspective, shot.pose)
+ shot_out.metadata = shot.metadata
+ undistorted = undistorter(image)
+ return shot_out, undistorted
+
+
+def undistort_shot(
+ image_raw: np.ndarray,
+ shot_orig: Shot,
+ undistorter,
+ pano_offset: float,
+) -> Tuple[List[Shot], List[np.ndarray]]:
+ camera = shot_orig.camera
+ if image_raw.shape[:2] != (camera.height, camera.width):
+ raise ValueError(
+ shot_orig.id, image_raw.shape[:2], (camera.height, camera.width)
+ )
+ if camera.is_panorama(camera.projection_type):
+ shots, undistorted = render_panorama(
+ shot_orig, image_raw, undistorter, offset=pano_offset
+ )
+ elif camera.projection_type in ("perspective", "fisheye"):
+ shot, undistorted = undistort_camera(shot_orig, image_raw, undistorter)
+ shots, undistorted = [shot], [undistorted]
+ else:
+ raise NotImplementedError(camera.projection_type)
+ return shots, undistorted
\ No newline at end of file
diff --git a/mia/misc_tools/calc_stats.py b/mia/misc_tools/calc_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..cccbaade05adaf4b6c049cc213df9f7a1a524159
--- /dev/null
+++ b/mia/misc_tools/calc_stats.py
@@ -0,0 +1,232 @@
+"""
+
+Example usage:
+ python3.9 -m mapper.data.debug.calc_stats -d /ocean/projects/cis220039p/shared/map_perception/dataset_v0
+"""
+import datetime
+from datetime import datetime, timezone, timedelta
+import time
+import argparse
+import os
+from pathlib import Path
+import json
+
+from astral import LocationInfo
+from astral.sun import sun
+from timezonefinder import TimezoneFinder
+
+import numpy as np
+import pandas as pd
+import geopandas as gpd
+from pyproj.transformer import Transformer
+from matplotlib import pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import tqdm
+
+from ..fpv import filters
+from .. import logger
+
+
+def is_daytime(timestamp, latitude, longitude):
+ # Create a LocationInfo object for the given latitude and longitude
+ tz_str = TimezoneFinder().timezone_at(lng=longitude, lat=latitude)
+ location = LocationInfo(name="", region="", timezone=tz_str,
+ latitude=latitude, longitude=longitude)
+
+ # Convert the timestamp to a datetime object
+ dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+ # We query one day before and one day after to avoid timezone ambiguities
+ # Our query timestamp is guaranteed to fall into one of those 3 dates.
+ # Astral sometimes returns sunrise or sunsets that are not from the same query date
+ # Refer to this https://github.com/sffjunkie/astral/issues/83
+ d0 = (dt - timedelta(days=1)).date()
+ d1 = dt.date()
+ d2 = (dt + timedelta(days=1)).date()
+
+ # Calculate sunrise and sunset times
+ times = list()
+ for d in [d0, d1, d2]:
+ s = sun(location.observer, date=d)
+ sunrise = s['sunrise']
+ sunset = s['sunset']
+ times.append((sunrise, "sunrise"))
+ times.append((sunset, 'sunset'))
+
+ # Need to sort because there is no particular order
+ # where sunrise is always before sunset or vice versa
+ times = sorted(times, key=lambda x: x[0])
+ assert times[-1][0] > dt > times[0][0]
+
+ for i in range(1, len(times)):
+ if dt < times[i][0]:
+ prev_event = times[i-1][1]
+ break
+
+ return prev_event == "sunrise"
+
+def calculate_occupancy_map(df: pd.DataFrame, bev_meter_coverage=112, meters_per_pixel=112):
+ """
+ Args:
+ bev_meter_coverage: How much did the BEVs in the dataframe cover in meters
+ meters_per_pixel: At what resolution should we initialize the occupancy map.
+ This need not be the same resolution as the BEV. That would be unnecessarilly slow but most accurate.
+ """
+ # convert pandas dataframe to geopandas dataframe
+ gdf = gpd.GeoDataFrame(df,
+ geometry=gpd.points_from_xy(
+ df['computed_geometry.long'],
+ df['computed_geometry.lat']),
+ crs=4326)
+
+ utm_crs = gdf.estimate_utm_crs()
+ gdf_utm = gdf.to_crs(utm_crs)
+ left = gdf_utm.geometry.x.min() - bev_meter_coverage
+ right = gdf_utm.geometry.x.max() + bev_meter_coverage
+ bottom = gdf_utm.geometry.y.min() - bev_meter_coverage
+ top = gdf_utm.geometry.y.max() + bev_meter_coverage
+
+ width = right - left
+ height = top - bottom
+ width_pixels = int(width // meters_per_pixel)
+ height_pixels = int(height // meters_per_pixel)
+ if bev_meter_coverage % meters_per_pixel != 0:
+ logger.warn(f"bev_meter_coverage {bev_meter_coverage} is not divisble by meters_per_pixel "
+ f"{meters_per_pixel}. Occupancy may be overestimated.")
+
+ bev_pixels = int(np.ceil(bev_meter_coverage / meters_per_pixel))
+
+ logger.info(f"Initializing {height_pixels}x{width_pixels} occupancy map. Using {bev_pixels}x{bev_pixels} pixels for each BEV.")
+ map = np.zeros((height_pixels, width_pixels), dtype=bool)
+
+ for row in gdf_utm.itertuples():
+ utm_x = row.geometry.x
+ utm_y = row.geometry.y
+ img_x = int((utm_x - left) // meters_per_pixel)
+ img_y = int((utm_y - bottom) // meters_per_pixel)
+
+ bev_pixels_left = bev_pixels // 2
+ bev_pixels_right = bev_pixels - bev_pixels_left
+ map[img_y - bev_pixels_left: img_y + bev_pixels_right,
+ img_x - bev_pixels_left: img_x + bev_pixels_right] = True
+
+ return map
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_dir", '-d', type=str, required=True, help="Dataset directory")
+ parser.add_argument("--locations", '-l', type=str, default="all",
+ help="Location names in CSV format. Set to 'all' to traverse all locations.")
+ parser.add_argument("--plot", action="store_true", help="Store plots per location in PDFs")
+ parser.add_argument("--output", "-o", default=None, type=str, help="output json file to store statistics")
+ args = parser.parse_args()
+
+ locations = list()
+ if args.locations.lower() == "all":
+ locations = os.listdir(args.dataset_dir)
+ locations = [l for l in locations if os.path.isdir(os.path.join(args.dataset_dir, l))]
+ else:
+ locations = args.locations.split(",")
+
+ logger.info(f"Parsing {len(locations)} locations..")
+
+ all_locs_stats = dict()
+
+ for location in tqdm.tqdm(locations):
+ dataset_dir = Path(args.dataset_dir)
+ location_dir = dataset_dir / location
+ bev_dir = location_dir / "bev_raw"
+ semantic_mask_dir = location_dir / "semantic_masks"
+ osm_cache_dir = location_dir / "osm_cache"
+
+ pq_name = 'image_metadata_filtered_processed.parquet'
+ df = pd.read_parquet(location_dir / pq_name)
+
+ df = df[df["computed_geometry.lat"].notna()]
+ df = df[df["computed_geometry.long"].notna()]
+
+ logger.info(f"Loaded {df.shape[0]} image metadata from {location}")
+
+ # Calc derrivative attributes
+ tqdm.tqdm.pandas()
+
+ df["loc_descrip"] = filters.haversine_np(
+ lon1=df["geometry.long"], lat1=df["geometry.lat"],
+ lon2=df["computed_geometry.long"], lat2=df["computed_geometry.lat"]
+ )
+
+ df["angle_descrip"] = filters.angle_dist(
+ df["compass_angle"],
+ df["computed_compass_angle"]
+ )
+
+ # FIXME: Super slow
+ # df["is_daytime"] = df.progress_apply(lambda x: is_daytime(x["captured_at"]*1e-3,
+ # x["computed_geometry.lat"],
+ # x["computed_geometry.long"]),
+ # axis="columns", raw=False, engine="python")
+
+ meters_per_pixel = 7
+ map = calculate_occupancy_map(df, bev_meter_coverage=112,
+ meters_per_pixel=meters_per_pixel)
+
+ # Calc aggregate stats
+ loc_stats = dict()
+ loc_stats["num_images"] = len(df)
+ loc_stats["area_covered_km2"] = np.sum(map) * meters_per_pixel ** 2 * 1e-6
+ loc_stats["camera_types"] = set(df["camera_type"].unique())
+ loc_stats["camera_makes"] = set(df["make"].unique())
+ loc_stats["camera_model"] = set(df["model"].unique())
+
+ all_locs_stats[location] = loc_stats
+
+ # Plot if requested
+ if args.plot:
+ with PdfPages(location_dir / "stats.pdf") as pdf:
+ plt.figure()
+ plt.imshow(map)
+ plt.title(f"{location} occupancy map")
+ pdf.savefig()
+ plt.close()
+ for k in ["make", "model", "camera_type", "loc_descrip",
+ "angle_descrip"]:
+ plt.figure()
+ df[k].hist()
+ plt.title(k)
+ plt.xlabel(k)
+ plt.xticks(rotation=90)
+ plt.ylabel("Count")
+ plt.tight_layout()
+ pdf.savefig()
+ plt.close()
+
+ # Aggregate all stats
+ aggregated_stats = dict()
+ for loc, loc_stats in all_locs_stats.items():
+ for k,v in loc_stats.items():
+ if isinstance(v, float) or isinstance(v, int):
+ if k not in aggregated_stats.keys():
+ aggregated_stats[k] = v
+ else:
+ aggregated_stats[k] += v
+ elif isinstance(v, set):
+ if k not in aggregated_stats.keys():
+ aggregated_stats[k] = v
+ else:
+ aggregated_stats[k] = aggregated_stats[k].union(v)
+ aggregated_stats[f"{k}_count"] = len(aggregated_stats[k])
+ else:
+ raise Exception(f"{v} is not supported !")
+
+ all_locs_stats["aggregated"] = aggregated_stats
+
+ print(all_locs_stats)
+
+ # Store for json
+ for loc, loc_stats in all_locs_stats.items():
+ for k,v in loc_stats.items():
+ if isinstance(v, set):
+ loc_stats[k] = list(v)
+
+ if args.output:
+ with open(args.output, "w") as f:
+ json.dump(all_locs_stats, f, indent=2)
\ No newline at end of file
diff --git a/mia/misc_tools/vis_samples.py b/mia/misc_tools/vis_samples.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14149ca5681ef264934a70f8c0f7c219cdacd79
--- /dev/null
+++ b/mia/misc_tools/vis_samples.py
@@ -0,0 +1,133 @@
+import argparse
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+from matplotlib.patches import Patch
+import pandas as pd
+import numpy as np
+import tqdm
+
+from ..bev.get_bev import mask2rgb, PRETTY_COLORS as COLORS, VIS_ORDER
+from ..fpv.filters import haversine_np, angle_dist
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_dir", '-d', type=str, required=True, help="Dataset directory")
+ parser.add_argument("--locations", '-l', type=str, default="all",
+ help="Location names in CSV format. Set to 'all' to traverse all locations.")
+ parser.add_argument("--rows", type=int, default=5, help="How many samples per PDF page")
+ parser.add_argument("--n_samples", type=int, default=30, help="How many samples to visualize?")
+ parser.add_argument("--store_sat", action="store_true", help="Add sattelite column")
+ args = parser.parse_args()
+
+ MAX_ROWS = args.rows
+ MAX_COLS = 4 if args.store_sat else 3
+ MAX_TEXT_LEN=30
+
+ locations = list()
+ if args.locations.lower() == "all":
+ locations = os.listdir(args.dataset_dir)
+ locations = [l for l in locations if os.path.isdir(os.path.join(args.dataset_dir, l))]
+ else:
+ locations = args.locations.split(",")
+
+ print(f"Parsing {len(locations)} locations..")
+
+ all_locs_stats = dict()
+
+ for location in tqdm.tqdm(locations):
+ dataset_dir = Path(args.dataset_dir)
+ location_dir = dataset_dir / location
+ semantic_mask_dir = location_dir / "semantic_masks"
+ sat_dir = location_dir / "sattelite"
+ comp_dir = location_dir / "images"
+
+ pq_name = 'image_metadata_filtered_processed.parquet'
+ df = pd.read_parquet(location_dir / pq_name)
+
+ # Calc derrivative attributes
+ df["loc_descrip"] = haversine_np(
+ lon1=df["geometry.long"], lat1=df["geometry.lat"],
+ lon2=df["computed_geometry.long"], lat2=df["computed_geometry.lat"]
+ )
+
+ df["angle_descrip"] = angle_dist(
+ df["compass_angle"],
+ df["computed_compass_angle"]
+ )
+
+ with PdfPages(location_dir / 'compare.pdf') as pdf:
+ # Plot legend page
+ plt.figure()
+ key2mask_i = dict(zip(COLORS.keys(), range(len(COLORS))))
+ patches = [Patch(color=COLORS[key], label=f"{key}") for i,key in enumerate(VIS_ORDER) if COLORS[key] is not None]
+ plt.legend(handles=patches, loc='center', title='Legend')
+ plt.axis("off")
+ plt.tight_layout()
+ pdf.savefig()
+ plt.close()
+
+ # Plot pairs
+ row_cnt = 0
+ fig = plt.figure(figsize=(MAX_COLS*2, MAX_ROWS*2))
+ for index, row in tqdm.tqdm(df.iterrows()):
+ id = row["id"]
+ mask_fp = semantic_mask_dir / f"{id}.npz"
+ comp_fp = comp_dir / f"{id}_undistorted.jpg"
+ sat_fp = sat_dir / f"{id}.png"
+ if not os.path.exists(mask_fp) or not os.path.exists(comp_fp) or \
+ (args.store_sat and not os.path.exists(sat_fp)):
+ continue
+ plt.subplot(MAX_ROWS, MAX_COLS, (row_cnt % MAX_ROWS)*MAX_COLS + 1)
+ plt.axis("off")
+ desc = list()
+
+ # Display attributes
+ keys = ["geometry.long", "geometry.lat", "compass_angle",
+ "loc_descrip", "angle_descrip",
+ "make", "model", "camera_type",
+ "quality_score"]
+ for k in keys:
+ v = row[k]
+ if isinstance(v, float):
+ v = f"{v:.4f}"
+ bullet = f"{k}: {v}"
+ if len(bullet) > MAX_TEXT_LEN:
+ bullet = bullet[:MAX_TEXT_LEN-2] + ".."
+ desc.append(bullet)
+ plt.text(0,0, "\n".join(desc), fontsize=7)
+ plt.title(id)
+ plt.subplot(MAX_ROWS, MAX_COLS, (row_cnt % MAX_ROWS)*MAX_COLS + 2)
+
+
+ mask = np.load(mask_fp)["arr_0"]
+ mask_rgb = mask2rgb(mask)
+ plt.imshow(mask_rgb); plt.axis("off")
+ plt.title(f"BEV")
+ H,W,_ = mask_rgb.shape
+ plt.scatter(np.array([H/2]), np.array([W/2]), marker="x")
+
+ plt.subplot(MAX_ROWS, MAX_COLS, (row_cnt % MAX_ROWS)*MAX_COLS + 3)
+
+ plt.imshow(plt.imread(comp_fp)); plt.axis("off")
+ plt.title(f"FPV")
+
+ if args.store_sat:
+ sat_fp = sat_dir / f"{id}.png"
+ plt.subplot(MAX_ROWS, MAX_COLS, (row_cnt % MAX_ROWS)*MAX_COLS + 4)
+ plt.imshow(plt.imread(sat_fp)); plt.axis("off")
+ plt.title(f"SAT")
+
+ row_cnt += 1
+ if row_cnt % MAX_ROWS == 0:
+ #plt.suptitle(location)
+ plt.tight_layout()
+ fig.align_titles()
+ pdf.savefig()
+ plt.close()
+ fig = plt.figure(figsize=(MAX_COLS*2, MAX_ROWS*2))
+
+ if row_cnt == args.n_samples:
+ break
\ No newline at end of file