# Mapillary (MIA-format) scene data module.
import json | |
from collections import defaultdict | |
import os | |
import shutil | |
import tarfile | |
from pathlib import Path | |
from typing import Optional | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
import torch.utils.data as torchdata | |
from omegaconf import DictConfig | |
from ... import logger | |
from .dataset import MapLocDataset | |
from ..sequential import chunk_sequence | |
from ..torch import collate, worker_init_fn | |
from ..schema import MIADataConfiguration | |
def pack_dump_dict(dump):
    """Convert the nested lists of a parsed scene dump (JSON) to numpy arrays, in place.

    For every sequence in ``dump``:
      * per-chunk point lists become ``float64`` arrays; chunks whose value is
        ``None`` are dropped entirely,
      * each view's ``R_c2w`` and ``roll_pitch_yaw`` become ``float32`` arrays,
      * per-view ``chunk_id`` entries are removed,
      * ``observations`` (if present) become arrays,
      * camera ``params`` become ``float32`` arrays.

    Returns the same (mutated) dict for call-chaining convenience.
    """
    for per_seq in dump.values():
        if "points" in per_seq:
            for chunk in list(per_seq["points"]):
                points = per_seq["points"].pop(chunk)
                if points is not None:
                    # Bug fix: convert the popped value itself. The original
                    # re-read per_seq["points"][chunk] after pop(), which
                    # raised KeyError for every non-None chunk.
                    per_seq["points"][chunk] = np.array(points, np.float64)
        for view in per_seq["views"].values():
            for k in ["R_c2w", "roll_pitch_yaw"]:
                view[k] = np.array(view[k], np.float32)
            for k in ["chunk_id"]:
                if k in view:
                    view.pop(k)
            if "observations" in view:
                view["observations"] = np.array(view["observations"])
        for camera in per_seq["cameras"].values():
            for k in ["params"]:
                camera[k] = np.array(camera[k], np.float32)
    return dump
class MapillaryDataModule(pl.LightningDataModule):
    """Lightning data module for Mapillary-derived scenes (MIA dump format).

    Expects, under ``cfg.data_dir/<scene>/``: a ``dump.json`` with per-sequence
    view/camera metadata, an image directory, semantic-mask and flood-fill-mask
    directories, and optionally an ``images.tar.gz`` archive used for local
    caching when ``self.local_dir`` is set.
    """

    # Fixed file/directory layout inside each scene folder.
    dump_filename = "dump.json"
    images_archive = "images.tar.gz"
    images_dirname = "images/"
    semantic_masks_dirname = "semantic_masks/"
    flood_dirname = "flood_fill/"

    def __init__(self, cfg: MIADataConfiguration):
        super().__init__()
        self.cfg = cfg
        # presumably a pathlib.Path — the `/` operator is used on it below; TODO confirm.
        self.root = self.cfg.data_dir
        # Optional local cache directory; None disables image-archive extraction.
        self.local_dir = None

    def prepare_data(self) -> None:
        """Verify each scene's expected files exist; extract images to the local cache if enabled."""
        for scene in self.cfg.scenes:
            dump_dir = self.root / scene
            assert (dump_dir / self.dump_filename).exists(), dump_dir
            # assert (dump_dir / self.cfg.tiles_filename).exists(), dump_dir
            if self.local_dir is None:
                # No local cache: images are read directly from the scene folder.
                assert (dump_dir / self.images_dirname).exists(), dump_dir
                continue
            assert (dump_dir / self.semantic_masks_dirname).exists(), dump_dir
            assert (dump_dir / self.flood_dirname).exists(), dump_dir
            # Cache the folder of images locally to speed up reading
            local_dir = self.local_dir / scene
            if local_dir.exists():
                # Start from a clean cache to avoid stale files.
                shutil.rmtree(local_dir)
            local_dir.mkdir(exist_ok=True, parents=True)
            images_archive = dump_dir / self.images_archive
            logger.info("Extracting the image archive %s.", images_archive)
            with tarfile.open(images_archive) as fp:
                # NOTE(review): extractall() without a member filter trusts the
                # archive contents — fine for self-produced dumps, unsafe for
                # untrusted archives.
                fp.extractall(local_dir)

    def setup(self, stage: Optional[str] = None):
        """Load the scene dumps, index images/masks on disk, and build the splits.

        Populates ``self.dumps``, the per-scene directory maps, ``self.splits``
        (via ``parse_splits`` and optional ``filter_elements``) and finally
        ``self.data`` (via ``pack_data``).
        """
        self.dumps = {}
        # self.tile_managers = {}
        self.image_dirs = {}
        self.seg_masks_dir = {}
        self.flood_masks_dir = {}
        # (scene, sequence, view_name) tuples for which all files are on disk.
        names = []
        for scene in self.cfg.scenes:
            logger.info("Loading scene %s.", scene)
            dump_dir = self.root / scene
            logger.info("Loading dump json file %s.", self.dump_filename)
            with (dump_dir / self.dump_filename).open("r") as fp:
                self.dumps[scene] = pack_dump_dict(json.load(fp))
            # Only pinhole cameras are supported downstream.
            for seq, per_seq in self.dumps[scene].items():
                for cam_id, cam_dict in per_seq["cameras"].items():
                    if cam_dict["model"] != "PINHOLE":
                        raise ValueError(
                            f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
                        )
            # Images and masks may live in the local cache (if prepare_data
            # extracted them) or directly under the data root.
            self.image_dirs[scene] = (
                (self.local_dir or self.root) / scene / self.images_dirname
            )
            assert self.image_dirs[scene].exists(), self.image_dirs[scene]
            self.seg_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.semantic_masks_dirname
            )
            assert self.seg_masks_dir[scene].exists(), self.seg_masks_dir[scene]
            self.flood_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.flood_dirname
            )
            assert self.flood_masks_dir[scene].exists(), self.flood_masks_dir[scene]
            # File stems (name without extension) actually present on disk.
            images = set(x.split('.')[0] for x in os.listdir(self.image_dirs[scene]))
            flood_masks = set(x.split('.')[0] for x in os.listdir(self.flood_masks_dir[scene]))
            semantic_masks = set(x.split('.')[0] for x in os.listdir(self.seg_masks_dir[scene]))
            for seq, data in self.dumps[scene].items():
                for name in data["views"]:
                    # Masks appear to be keyed by the id before the first "_"
                    # of the view name — TODO confirm naming convention.
                    if name in images and name.split("_")[0] in flood_masks and name.split("_")[0] in semantic_masks:
                        names.append((scene, seq, name))
        self.parse_splits(self.cfg.split, names)
        if self.cfg.filter_for is not None:
            self.filter_elements()
        self.pack_data()

    def pack_data(self):
        """Pack per-view metadata into per-stage dicts of contiguous arrays/tensors."""
        # We pack the data into compact tensors that can be shared across processes without copy
        exclude = {
            "compass_angle",
            "compass_accuracy",
            "gps_accuracy",
            "chunk_key",
            "panorama_offset",
        }
        cameras = {
            scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
            for scene, per_scene in self.dumps.items()
        }
        points = {
            scene: {
                seq: {
                    i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
                }
                for seq, per_seq in per_scene.items()
            }
            for scene, per_scene in self.dumps.items()
        }
        self.data = {}
        # TODO: remove
        # Split-file-specific hacks below; see the TODO above.
        if self.cfg.split == "splits_MGL_13loc.json":
            # Use Last 20% as Val
            num_samples_to_move = int(len(self.splits['train']) * 0.2)
            samples_to_move = self.splits['train'][-num_samples_to_move:]
            self.splits['val'].extend(samples_to_move)
            self.splits['train'] = self.splits['train'][:-num_samples_to_move]
            print(f"Dataset Len: {len(self.splits['train']), len(self.splits['val'])}\n\n\n\n")
        elif self.cfg.split == "splits_MGL_soma_70k_mappred_random.json":
            for stage, names in self.splits.items():
                print("Length of splits {}: ".format(stage), len(self.splits[stage]))
        for stage, names in self.splits.items():
            # Use the first listed view as a template for which keys to pack.
            view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
            data = {k: [] for k in view.keys() - exclude}
            for scene, seq, name in names:
                for k in data:
                    # .get(None) tolerates views missing a key the template had.
                    data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
            for k in data:
                v = np.array(data[k])
                if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
                    v.dtype, np.floating
                ):
                    # Numeric columns become torch tensors (shareable across
                    # dataloader workers); other dtypes stay numpy arrays.
                    v = torch.from_numpy(v)
                data[k] = v
            data["cameras"] = cameras
            data["points"] = points
            self.data[stage] = data
            self.splits[stage] = np.array(names)

    def filter_elements(self):
        """Drop views failing the ``cfg.filter_for`` criterion from every split."""
        for stage, names in self.splits.items():
            names_select = []
            for scene, seq, name in names:
                view = self.dumps[scene][seq]["views"][name]
                if self.cfg.filter_for == "ground_plane":
                    # Keep views with a plausible camera height (1–3 m).
                    if not (1.0 <= view["height"] <= 3.0):
                        continue
                    planes = self.dumps[scene][seq].get("plane")
                    if planes is not None:
                        # Last entry of the per-chunk plane record looks like
                        # an inlier count — TODO confirm dump format.
                        inliers = planes[str(view["chunk_id"])][-1]
                        if inliers < 10:
                            continue
                    if self.cfg.filter_by_ground_angle is not None:
                        plane = np.array(view["plane_params"])
                        normal = plane[:3] / np.linalg.norm(plane[:3])
                        # Angle between the plane normal and the vertical axis.
                        angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
                        if angle > self.cfg.filter_by_ground_angle:
                            continue
                elif self.cfg.filter_for == "pointcloud":
                    # Require a minimum number of triangulated observations.
                    if len(view["observations"]) < self.cfg.min_num_points:
                        continue
                elif self.cfg.filter_for is not None:
                    raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
                names_select.append((scene, seq, name))
            logger.info(
                "%s: Keep %d/%d images after filtering for %s.",
                stage,
                len(names_select),
                len(names),
                self.cfg.filter_for,
            )
            self.splits[stage] = names_select

    def parse_splits(self, split_arg, names):
        """Build ``self.splits`` ({"train": [...], "val": [...]}) from ``split_arg``.

        Supported forms:
          * None — train and val both use all names;
          * int — hold out the first ``split_arg`` shuffled names for val;
          * float — hold out that fraction of shuffled names for val;
          * DictConfig — per-stage lists of scene names;
          * str — path to a JSON file mapping stage -> {scene: [image ids]}.
        """
        if split_arg is None:
            self.splits = {
                "train": names,
                "val": names,
            }
        elif isinstance(split_arg, int):
            # Seeded RandomState keeps the shuffle reproducible across runs.
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[split_arg:],
                "val": names[:split_arg],
            }
        elif isinstance(split_arg, float):
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[int(split_arg * len(names)) :],
                "val": names[: int(split_arg * len(names))],
            }
        elif isinstance(split_arg, DictConfig):
            scenes_val = set(split_arg.val)
            scenes_train = set(split_arg.train)
            # Every scene listed in the split must be among the loaded scenes.
            assert len(scenes_val - set(self.cfg.scenes)) == 0
            assert len(scenes_train - set(self.cfg.scenes)) == 0
            self.splits = {
                "train": [n for n in names if n[0] in scenes_train],
                "val": [n for n in names if n[0] in scenes_val],
            }
        elif isinstance(split_arg, str):
            # NOTE(review): a name containing "/" is resolved under self.root,
            # a bare name relative to the working directory — confirm intended.
            if "/" in split_arg:
                split_path = self.root / split_arg
            else:
                split_path = Path(split_arg)
            with split_path.open("r") as fp:
                splits = json.load(fp)
            # Convert id lists to sets for O(1) membership tests below.
            splits = {
                k: {loc: set(ids) for loc, ids in split.items()}
                for k, split in splits.items()
            }
            self.splits = {}
            for k, split in splits.items():
                # A view belongs to stage k if its scene is listed and its
                # numeric id (view name minus the trailing "_<suffix>") is listed.
                self.splits[k] = [
                    n
                    for n in names
                    if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
                ]
        else:
            raise ValueError(split_arg)

    def dataset(self, stage: str):
        """Instantiate the MapLocDataset for one stage (e.g. "train" or "val")."""
        return MapLocDataset(
            stage,
            self.cfg,
            self.splits[stage],
            self.data[stage],
            self.image_dirs,
            self.seg_masks_dir,
            self.flood_masks_dir,
            image_ext=".jpg",
        )

    def sequence_dataset(self, stage: str, **kwargs):
        """Return ``(dataset, chunk2indices)`` with sequences cut into chunks.

        ``chunk2indices`` maps ``(sequence, chunk_number)`` to the dataset
        indices belonging to that chunk; ``kwargs`` are forwarded to
        ``chunk_sequence``.
        """
        keys = self.splits[stage]
        seq2indices = defaultdict(list)
        # Group dataset indices by their sequence id.
        for index, (_, seq, _) in enumerate(keys):
            seq2indices[seq].append(index)
        # chunk the sequences to the required length
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(self.data[stage], indices, **kwargs)
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # store the index of each chunk in its sequence
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
        """Dataloader iterating whole chunks in order; optionally shuffles chunk order.

        Returns ``(loader, chunk_keys, chunk2idx)``. Frames within a chunk keep
        their order; only the chunk ordering is shuffled.
        """
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        chunk_keys = sorted(chunk2idx)
        if shuffle:
            perm = torch.randperm(len(chunk_keys))
            chunk_keys = [chunk_keys[i] for i in perm]
        # Flatten chunk order into a per-frame sampling order.
        key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
        num_workers = self.cfg.loading[stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            batch_size=None,
            # The custom sampler fixes the order, so shuffle must stay False.
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, chunk_keys, chunk2idx