import json
from collections import defaultdict
import os
import shutil
import tarfile
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import DictConfig
from ... import logger
from .dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from ..schema import MIADataConfiguration
def pack_dump_dict(dump):
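    """Convert the nested lists of a parsed dump.json into numpy arrays (in place)."""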
for per_seq in dump.values():
if "points" in per_seq:
for chunk in list(per_seq["points"]):
points = per_seq["points"].pop(chunk)
if points is not None:
per_seq["points"][chunk] = np.array(
per_seq["points"][chunk], np.float64
)
for view in per_seq["views"].values():
for k in ["R_c2w", "roll_pitch_yaw"]:
view[k] = np.array(view[k], np.float32)
for k in ["chunk_id"]:
if k in view:
view.pop(k)
if "observations" in view:
view["observations"] = np.array(view["observations"])
for camera in per_seq["cameras"].values():
for k in ["params"]:
camera[k] = np.array(camera[k], np.float32)
return dump
class MapillaryDataModule(pl.LightningDataModule):
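    """Lightning data module for the Mapillary-based MIA scenes.

    Expects, for each scene, a dump.json with per-sequence views and cameras,
    plus image, semantic-mask, and flood-fill-mask directories; builds the
    train/val splits and packs the view metadata into tensors that can be
    shared across worker processes.
    """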
dump_filename = "dump.json"
images_archive = "images.tar.gz"
images_dirname = "images/"
semantic_masks_dirname = "semantic_masks/"
flood_dirname = "flood_fill/"
def __init__(self, cfg: MIADataConfiguration):
super().__init__()
self.cfg = cfg
self.root = self.cfg.data_dir
self.local_dir = None
def prepare_data(self):
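        """Verify that every scene directory contains the expected files and,
        if a local cache directory is set, extract its image archive there."""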
for scene in self.cfg.scenes:
dump_dir = self.root / scene
assert (dump_dir / self.dump_filename).exists(), dump_dir
# assert (dump_dir / self.cfg.tiles_filename).exists(), dump_dir
if self.local_dir is None:
assert (dump_dir / self.images_dirname).exists(), dump_dir
continue
assert (dump_dir / self.semantic_masks_dirname).exists(), dump_dir
assert (dump_dir / self.flood_dirname).exists(), dump_dir
# Cache the folder of images locally to speed up reading
local_dir = self.local_dir / scene
if local_dir.exists():
shutil.rmtree(local_dir)
local_dir.mkdir(exist_ok=True, parents=True)
images_archive = dump_dir / self.images_archive
logger.info("Extracting the image archive %s.", images_archive)
with tarfile.open(images_archive) as fp:
fp.extractall(local_dir)
def setup(self, stage: Optional[str] = None):
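        """Load each scene's dump.json, locate the image and mask directories,
        keep only views that have an image, a semantic mask, and a flood-fill
        mask, then parse, filter, and pack the train/val splits."""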
self.dumps = {}
# self.tile_managers = {}
self.image_dirs = {}
self.seg_masks_dir = {}
self.flood_masks_dir = {}
names = []
for scene in self.cfg.scenes:
logger.info("Loading scene %s.", scene)
dump_dir = self.root / scene
logger.info("Loading dump json file %s.", self.dump_filename)
with (dump_dir / self.dump_filename).open("r") as fp:
self.dumps[scene] = pack_dump_dict(json.load(fp))
for seq, per_seq in self.dumps[scene].items():
for cam_id, cam_dict in per_seq["cameras"].items():
if cam_dict["model"] != "PINHOLE":
raise ValueError(
f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
)
self.image_dirs[scene] = (
(self.local_dir or self.root) / scene / self.images_dirname
)
assert self.image_dirs[scene].exists(), self.image_dirs[scene]
self.seg_masks_dir[scene] = (
(self.local_dir or self.root) / scene / self.semantic_masks_dirname
)
assert self.seg_masks_dir[scene].exists(), self.seg_masks_dir[scene]
self.flood_masks_dir[scene] = (
(self.local_dir or self.root) / scene / self.flood_dirname
)
assert self.flood_masks_dir[scene].exists(), self.flood_masks_dir[scene]
            images = {x.split(".")[0] for x in os.listdir(self.image_dirs[scene])}
            flood_masks = {x.split(".")[0] for x in os.listdir(self.flood_masks_dir[scene])}
            semantic_masks = {x.split(".")[0] for x in os.listdir(self.seg_masks_dir[scene])}
for seq, data in self.dumps[scene].items():
for name in data["views"]:
                    if (
                        name in images
                        and name.split("_")[0] in flood_masks
                        and name.split("_")[0] in semantic_masks
                    ):
                        names.append((scene, seq, name))
self.parse_splits(self.cfg.split, names)
if self.cfg.filter_for is not None:
self.filter_elements()
self.pack_data()
def pack_data(self):
        """Pack the view metadata into compact tensors that can be shared
        across processes without copy."""
exclude = {
"compass_angle",
"compass_accuracy",
"gps_accuracy",
"chunk_key",
"panorama_offset",
}
cameras = {
scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
for scene, per_scene in self.dumps.items()
}
points = {
scene: {
seq: {
i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
}
for seq, per_seq in per_scene.items()
}
for scene, per_scene in self.dumps.items()
}
self.data = {}
# TODO: remove
if self.cfg.split == "splits_MGL_13loc.json":
            # Use the last 20% of the training split as validation
            num_samples_to_move = int(len(self.splits["train"]) * 0.2)
            samples_to_move = self.splits["train"][-num_samples_to_move:]
            self.splits["val"].extend(samples_to_move)
            self.splits["train"] = self.splits["train"][:-num_samples_to_move]
            logger.info(
                "Split lengths: train=%d, val=%d",
                len(self.splits["train"]),
                len(self.splits["val"]),
            )
        elif self.cfg.split == "splits_MGL_soma_70k_mappred_random.json":
            for stage, names in self.splits.items():
                logger.info("Length of split %s: %d", stage, len(names))
for stage, names in self.splits.items():
view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
data = {k: [] for k in view.keys() - exclude}
for scene, seq, name in names:
for k in data:
data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
for k in data:
v = np.array(data[k])
if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
v.dtype, np.floating
):
v = torch.from_numpy(v)
data[k] = v
data["cameras"] = cameras
data["points"] = points
self.data[stage] = data
self.splits[stage] = np.array(names)
def filter_elements(self):
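        """Remove views that fail the configured `filter_for` criterion,
        either ground-plane quality checks or a minimum number of observed
        3D points."""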
for stage, names in self.splits.items():
names_select = []
for scene, seq, name in names:
view = self.dumps[scene][seq]["views"][name]
if self.cfg.filter_for == "ground_plane":
if not (1.0 <= view["height"] <= 3.0):
continue
planes = self.dumps[scene][seq].get("plane")
if planes is not None:
inliers = planes[str(view["chunk_id"])][-1]
if inliers < 10:
continue
if self.cfg.filter_by_ground_angle is not None:
plane = np.array(view["plane_params"])
normal = plane[:3] / np.linalg.norm(plane[:3])
angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
if angle > self.cfg.filter_by_ground_angle:
continue
elif self.cfg.filter_for == "pointcloud":
if len(view["observations"]) < self.cfg.min_num_points:
continue
elif self.cfg.filter_for is not None:
raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
names_select.append((scene, seq, name))
logger.info(
"%s: Keep %d/%d images after filtering for %s.",
stage,
len(names_select),
len(names),
self.cfg.filter_for,
)
self.splits[stage] = names_select
def parse_splits(self, split_arg, names):
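        """Build self.splits from `split_arg`: None uses all views for both
        train and val; an int or float holds out that many views (or that
        fraction) for val; a DictConfig lists the train/val scenes; a string
        is the path to a JSON file mapping split -> scene -> image ids."""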
if split_arg is None:
self.splits = {
"train": names,
"val": names,
}
elif isinstance(split_arg, int):
names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
self.splits = {
"train": names[split_arg:],
"val": names[:split_arg],
}
elif isinstance(split_arg, float):
names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
self.splits = {
"train": names[int(split_arg * len(names)) :],
"val": names[: int(split_arg * len(names))],
}
elif isinstance(split_arg, DictConfig):
scenes_val = set(split_arg.val)
scenes_train = set(split_arg.train)
assert len(scenes_val - set(self.cfg.scenes)) == 0
assert len(scenes_train - set(self.cfg.scenes)) == 0
self.splits = {
"train": [n for n in names if n[0] in scenes_train],
"val": [n for n in names if n[0] in scenes_val],
}
elif isinstance(split_arg, str):
if "/" in split_arg:
split_path = self.root / split_arg
else:
split_path = Path(split_arg)
with split_path.open("r") as fp:
splits = json.load(fp)
splits = {
k: {loc: set(ids) for loc, ids in split.items()}
for k, split in splits.items()
}
self.splits = {}
for k, split in splits.items():
self.splits[k] = [
n
for n in names
if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
]
else:
raise ValueError(split_arg)
def dataset(self, stage: str):
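        """Instantiate a MapLocDataset for the given stage."""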
return MapLocDataset(
stage,
self.cfg,
self.splits[stage],
self.data[stage],
self.image_dirs,
self.seg_masks_dir,
self.flood_masks_dir,
image_ext=".jpg",
)
def sequence_dataset(self, stage: str, **kwargs):
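        """Split each sequence of a stage into chunks and return the dataset
        together with a mapping from (sequence, chunk index) to view indices."""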
keys = self.splits[stage]
seq2indices = defaultdict(list)
for index, (_, seq, _) in enumerate(keys):
seq2indices[seq].append(index)
# chunk the sequences to the required length
chunk2indices = {}
for seq, indices in seq2indices.items():
chunks = chunk_sequence(self.data[stage], indices, **kwargs)
for i, sub_indices in enumerate(chunks):
chunk2indices[seq, i] = sub_indices
# store the index of each chunk in its sequence
chunk_indices = torch.full((len(keys),), -1)
for (_, chunk_index), idx in chunk2indices.items():
chunk_indices[idx] = chunk_index
self.data[stage]["chunk_index"] = chunk_indices
dataset = self.dataset(stage)
return dataset, chunk2indices
def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
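        """Build a DataLoader that iterates over the views chunk by chunk,
        optionally shuffling the order of the chunks."""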
dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
chunk_keys = sorted(chunk2idx)
if shuffle:
perm = torch.randperm(len(chunk_keys))
chunk_keys = [chunk_keys[i] for i in perm]
key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
num_workers = self.cfg.loading[stage]["num_workers"]
loader = torchdata.DataLoader(
dataset,
batch_size=None,
sampler=key_indices,
num_workers=num_workers,
shuffle=False,
pin_memory=True,
persistent_workers=num_workers > 0,
worker_init_fn=worker_init_fn,
collate_fn=collate,
)
return loader, chunk_keys, chunk2idx
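

# Example usage (a minimal sketch, not part of the module's API): it assumes a
# `cfg: MIADataConfiguration` built elsewhere, e.g. from a YAML/OmegaConf file,
# that provides the fields accessed above (`data_dir`, `scenes`, `split`,
# `loading`, `seed`, `filter_for`, ...).
#
#     datamodule = MapillaryDataModule(cfg)
#     datamodule.prepare_data()
#     datamodule.setup()
#     loader, chunk_keys, chunk2idx = datamodule.sequence_dataloader("val")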