Luuu / torch_lydorn /torchvision /datasets /mapping_challenge.py
白鹭先生
init
abd2a81
raw
history blame
14.1 kB
import os
import pathlib
import warnings
import skimage.io
from multiprocess import Pool
from functools import partial
import numpy as np
from pycocotools.coco import COCO
import shapely.geometry
from tqdm import tqdm
import torch
from lydorn_utils import print_utils
from lydorn_utils import python_utils
from torch_lydorn.torch.utils.data import Dataset as LydornDataset, makedirs, files_exist, __repr__
from torch_lydorn.torchvision.datasets import utils
class MappingChallenge(LydornDataset):
    """Mapping Challenge dataset (COCO-format annotations), pre-processed once into per-image ``.pt`` files.

    Files read and written:

    - ``<raw_dir>/<fold>/annotation.json`` (``annotation-small.json`` when ``small``): COCO annotations.
      ``raw_dir`` comes from the base class (presumably ``<root>/raw`` — TODO confirm).
    - ``<root>/raw/<fold>/images/<file_name>``: the images referenced by the annotations.
    - ``<root>/processed/<fold>/``: one ``data_<image_id:012d>.pt`` per sample, ``stats.pt``
      (dataset-wide image mean/std and class frequencies), the cached image id list,
      ``pre_transform.pt`` / ``pre_filter.pt`` (reprs) and a "processed" flag file.

    Args:
        root: Dataset root directory.
        transform: Online transform, forwarded to the base dataset class.
        pre_transform: Offline transform applied once per sample by ``preprocess_one`` before saving.
        fold: ``"train"`` or ``"val"``; ``"test_images"`` passes the check but is not implemented.
        small: Use the small annotation file; the id list and processed flag get a ``-small`` suffix.
        pool_size: Number of worker processes used by :meth:`process`.

    Raises:
        AssertionError: If ``fold`` is not one of the accepted values.
        NotImplementedError: If ``fold == "test_images"``.
    """

    def __init__(self, root, transform=None, pre_transform=None, fold="train", small=False, pool_size=1):
        assert fold in ["train", "val", "test_images"], "Input fold={} should be in [\"train\", \"val\", \"test_images\"]".format(fold)
        if fold == "test_images":
            print_utils.print_error("ERROR: fold {} not yet implemented!".format(fold))
            # Raise rather than exit(): library code must not terminate the caller's interpreter.
            raise NotImplementedError("fold {} not yet implemented".format(fold))
        self.root = root
        self.fold = fold
        makedirs(self.processed_dir)
        self.small = small
        if self.small:
            print_utils.print_info("INFO: Using small version of the Mapping challenge dataset.")
        self.pool_size = pool_size
        # COCO annotations are only parsed on demand by get_coco().
        self.coco = None
        self.image_id_list = self.load_image_ids()
        self.stats_filepath = os.path.join(self.processed_dir, "stats.pt")
        self.stats = None
        if os.path.exists(self.stats_filepath):
            self.stats = self._load_trusted(self.stats_filepath)
        self.processed_flag_filepath = os.path.join(self.processed_dir, "processed-flag-small" if self.small else "processed-flag")
        # Everything _process()/process() need is set above, since the base constructor
        # presumably triggers _process() (TODO confirm against LydornDataset).
        super(MappingChallenge, self).__init__(root, transform, pre_transform)

    @staticmethod
    def _load_trusted(filepath):
        """Load a file this dataset saved with ``torch.save``, allowing arbitrary pickled objects.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which rejects the numpy
        arrays and shapely polygons stored in the processed files. Only use this on files written
        by this dataset: full unpickling of an untrusted file can execute arbitrary code.
        """
        import inspect
        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(filepath, weights_only=False)
        # PyTorch < 1.13 has no ``weights_only`` argument and always fully unpickles.
        return torch.load(filepath)

    def load_image_ids(self):
        """Return the list of image ids of this fold, cached as JSON in ``processed_dir``.

        On first use the ids come from ``coco.getImgIds(catIds=coco.getCatIds())``. pycocotools
        derives these from the annotations, so images without any annotation are not listed.
        """
        image_id_list_filepath = os.path.join(self.processed_dir, "image_id_list-small.json" if self.small else "image_id_list.json")
        if os.path.exists(image_id_list_filepath):
            image_id_list = python_utils.load_json(image_id_list_filepath)
        else:
            coco = self.get_coco()
            image_id_list = coco.getImgIds(catIds=coco.getCatIds())
            # Save for later so that the whole coco object doesn't have to be instantiated when just reading processed samples with multiple workers:
            python_utils.save_json(image_id_list_filepath, image_id_list)
        return image_id_list

    def get_coco(self):
        """Return the COCO annotation object of this fold, creating it on first call."""
        if self.coco is None:
            annotation_filename = "annotation-small.json" if self.small else "annotation.json"
            annotations_filepath = os.path.join(self.raw_dir, self.fold, annotation_filename)
            self.coco = COCO(annotations_filepath)
        return self.coco

    @property
    def processed_dir(self):
        """Directory holding the processed files of this fold: ``<root>/processed/<fold>``."""
        return os.path.join(self.root, 'processed', self.fold)

    @property
    def processed_file_names(self):
        """Names (relative to ``processed_dir``) of the per-sample processed files."""
        return ["data_{:012d}.pt".format(image_id) for image_id in self.image_id_list]

    def __len__(self):
        """Number of image ids of this fold."""
        return len(self.image_id_list)

    def _download(self):
        # Downloading is not supported: the raw data must already be under root.
        pass

    def download(self):
        pass

    def _process(self):
        """Run :meth:`process` unless the processed flag file exists.

        Warns when the stored repr of ``pre_transform``/``pre_filter`` differs from the current one,
        since existing processed files were produced with the stored one.
        """
        f = os.path.join(self.processed_dir, 'pre_transform.pt')
        if os.path.exists(f) and self._load_trusted(f) != __repr__(self.pre_transform):
            warnings.warn(
                'The `pre_transform` argument differs from the one used in '
                'the pre-processed version of this dataset. If you really '
                'want to make use of another pre-processing technique, make '
                'sure to delete `{}` first.'.format(self.processed_dir))
        f = os.path.join(self.processed_dir, 'pre_filter.pt')
        if os.path.exists(f) and self._load_trusted(f) != __repr__(self.pre_filter):
            warnings.warn(
                'The `pre_filter` argument differs from the one used in the '
                'pre-processed version of this dataset. If you really want to '
                'make use of another pre-fitering technique, make sure to '
                'delete `{}` first.'.format(self.processed_dir))
        if os.path.exists(self.processed_flag_filepath):
            return
        print('Processing...')
        makedirs(self.processed_dir)
        self.process()
        path = os.path.join(self.processed_dir, 'pre_transform.pt')
        torch.save(__repr__(self.pre_transform), path)
        path = os.path.join(self.processed_dir, 'pre_filter.pt')
        torch.save(__repr__(self.pre_filter), path)
        print('Done!')

    def process(self):
        """Pre-process every image of the fold in parallel, then compute and save dataset statistics.

        Each sample goes through ``preprocess_one``, which writes ``data_<image_id:012d>.pt``. The
        per-sample sums it returns are aggregated into ``self.stats``:

        - ``image_mean`` / ``image_std``: pixel values scaled by 1/255.
        - ``class_freq``: pixel-weighted class frequencies.

        ``self.stats`` is saved to ``stats.pt``. The processed flag file is touched last, so an
        interrupted run is redone on the next instantiation.

        Raises:
            ValueError: If no sample produced statistics (e.g. every sample was rejected by ``pre_filter``).
        """
        images_relative_dirpath = os.path.join("raw", self.fold, "images")
        coco = self.get_coco()
        image_info_list = []
        for image_id in self.image_id_list:
            filename = coco.loadImgs(image_id)[0]["file_name"]
            annotation_list = coco.loadAnns(coco.getAnnIds(imgIds=image_id))
            image_info_list.append({
                "image_id": image_id,
                "image_filepath": os.path.join(self.root, images_relative_dirpath, filename),
                "image_relative_filepath": os.path.join(images_relative_dirpath, filename),
                "annotation_list": annotation_list,
            })
        partial_preprocess_one = partial(preprocess_one, pre_filter=self.pre_filter, pre_transform=self.pre_transform,
                                         processed_dir=self.processed_dir)
        with Pool(self.pool_size) as p:
            sample_stats_list = list(tqdm(p.imap(partial_preprocess_one, image_info_list), total=len(image_info_list)))

        # preprocess_one returns None for samples rejected by pre_filter; zip(*...) below would raise on them.
        # NOTE(review): rejected ids stay in image_id_list without a data file, so get() on them fails — TODO.
        sample_stats_list = [sample_stats for sample_stats in sample_stats_list if sample_stats is not None]
        if not sample_stats_list:
            raise ValueError("No sample statistics to aggregate: the fold is empty or every sample "
                             "was rejected by pre_filter.")

        # Aggregate sample_stats_list
        image_s0_list, image_s1_list, image_s2_list, class_freq_list = zip(*sample_stats_list)
        image_s0_array = np.stack(image_s0_list, axis=0)  # Pixel count per sample, shape (N,)
        image_s1_total = np.sum(np.stack(image_s1_list, axis=0), axis=0)
        image_s2_total = np.sum(np.stack(image_s2_list, axis=0), axis=0)
        class_freq_array = np.stack(class_freq_list, axis=0)
        image_s0_total = np.sum(image_s0_array, axis=0)
        image_mean = image_s1_total / image_s0_total
        # Var = E[x^2] - E[x]^2 can come out slightly negative from float rounding
        # (e.g. a constant channel); clamp so the square root cannot produce NaN.
        image_var = np.maximum(image_s2_total / image_s0_total - np.power(image_mean, 2), 0)
        image_std = np.sqrt(image_var)
        # Per-sample class frequencies are means over pixels: weight them by each sample's pixel count.
        class_freq = np.sum(class_freq_array * image_s0_array[:, None], axis=0) / image_s0_total

        # Save aggregated stats
        self.stats = {
            "image_mean": image_mean,
            "image_std": image_std,
            "class_freq": class_freq,
        }
        torch.save(self.stats, self.stats_filepath)
        # Indicates that processing has been performed:
        pathlib.Path(self.processed_flag_filepath).touch()

    def get(self, idx):
        """Return the processed sample at ``idx`` with the dataset-wide stats added to it."""
        image_id = self.image_id_list[idx]
        data = self._load_trusted(os.path.join(self.processed_dir, "data_{:012d}.pt".format(image_id)))
        data["image_mean"] = self.stats["image_mean"]
        data["image_std"] = self.stats["image_std"]
        data["class_freq"] = self.stats["class_freq"]
        return data
def _torch_load_trusted(filepath):
    """Load a file written by ``preprocess_one`` with ``torch.save``, allowing arbitrary pickled objects.

    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which rejects the numpy arrays
    and shapely polygons stored in processed samples. Only use this on files written by this module:
    full unpickling of an untrusted file can execute arbitrary code.
    """
    import inspect
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(filepath, weights_only=False)
    # PyTorch < 1.13 has no ``weights_only`` argument and always fully unpickles.
    return torch.load(filepath)


def _annotations_to_polygons(annotation_list):
    """Convert COCO polygon annotations to shapely Polygons, dropping degenerate ones.

    Raises:
        NotImplementedError: If an annotation does not have exactly one segmentation list.
    """
    gt_polygons = []
    for annotation in annotation_list:
        flattened_segmentation_list = annotation["segmentation"]
        if len(flattened_segmentation_list) != 1:
            raise NotImplementedError(
                "len(flattened_segmentation_list) = {} (expected 1). To implement: if more than one "
                "segmentation in flattened_segmentation_list (MS COCO format), does it mean it is a "
                "MultiPolygon or a Polygon with holes?".format(len(flattened_segmentation_list)))
        # [[x0, y0, x1, y1, ...]] -> one (x, y) row per vertex
        coords = np.reshape(np.array(flattened_segmentation_list), (-1, 2))
        # Fewer than 3 vertices cannot form a polygon (shapely raises on them): degenerate, skip.
        if len(coords) < 3:
            continue
        polygon = shapely.geometry.Polygon(coords)
        # Filter out degenerate polygons (area of 2.0 or less)
        if 2.0 < polygon.area:
            gt_polygons.append(polygon)
    return gt_polygons


def preprocess_one(image_info, pre_filter, pre_transform, processed_dir):
    """Build (or reload from cache) one processed sample and return its statistics.

    The sample is cached as ``<processed_dir>/data_<image_id:012d>.pt``. An existing file is reused,
    and an unreadable (truncated/corrupted) one is rebuilt.

    Args:
        image_info: Dict with ``"image_id"``, ``"image_filepath"``, ``"image_relative_filepath"`` and
            ``"annotation_list"`` (COCO annotations). Only ``"image_id"`` is read when the cache is valid.
        pre_filter: Optional predicate on the raw sample dict; a falsy result skips the sample.
            It is only applied when the sample is (re)built, not to cached samples.
        pre_transform: Optional callable applied to the raw sample dict before saving.
        processed_dir: Directory holding the cached samples.

    Returns:
        ``(image_s0, image_s1, image_s2, class_freq)``, or ``None`` if ``pre_filter`` rejected the
        sample (nothing is saved in that case). The tuple holds:

        - ``image_s0``: number of pixels.
        - ``image_s1`` / ``image_s2``: sums of pixel values / of their squares, scaled by 1/255 and
          summed over the first two axes.
        - ``class_freq``: mean of ``gt_polygons_image`` over the first two axes, divided by 255.
    """
    out_filepath = os.path.join(processed_dir, "data_{:012d}.pt".format(image_info["image_id"]))
    data = None
    if os.path.exists(out_filepath):
        # Load already-processed sample
        try:
            data = _torch_load_trusted(out_filepath)
        except (EOFError, RuntimeError):
            # Truncated or corrupted cache file (e.g. an interrupted torch.save): rebuild it below.
            # A truncated legacy file raises EOFError, a damaged zip-format file raises RuntimeError.
            data = None
    if data is None:
        # Process sample:
        image = skimage.io.imread(image_info["image_filepath"])
        data = {
            "image": image,
            "gt_polygons": _annotations_to_polygons(image_info["annotation_list"]),
            "image_relative_filepath": image_info["image_relative_filepath"],
            "name": os.path.splitext(os.path.basename(image_info["image_relative_filepath"]))[0],
            "image_id": image_info["image_id"],
        }
        if pre_filter is not None and not pre_filter(data):
            return None
        if pre_transform is not None:
            data = pre_transform(data)
        torch.save(data, out_filepath)

    # Stats for later aggregation over the whole dataset (see MappingChallenge.process()).
    # Assumes data["image"] is (H, W, ...) with 0-255 values and that pre_transform added
    # "gt_polygons_image" — TODO confirm against the pre_transform in use.
    image = data["image"]
    normed_image = image / 255
    image_s0 = image.shape[0] * image.shape[1]  # Number of pixels
    image_s1 = np.sum(normed_image, axis=(0, 1))  # Sum of normalized pixel values
    image_s2 = np.sum(np.power(normed_image, 2), axis=(0, 1))  # Sum of squared normalized pixel values
    class_freq = np.mean(data["gt_polygons_image"], axis=(0, 1)) / 255
    return image_s0, image_s1, image_s2, class_freq
def main():
    """Smoke-test MappingChallenge with the transforms of the frame_field_learning project.

    Steps:

    1. Load the small train fold from the first existing directory in ``data_dir_candidates``.
    2. Print information about sample 0.
    3. Iterate batches. Each batch is moved with ``utils.batch_to_cuda``, passed through the online
       CUDA transform and moved back with ``utils.batch_to_cpu``.
    4. Write PNG visualizations to the current directory and wait for Enter after each batch.

    Exits with status 1 if no data directory is found.
    """
    # Local import: within this module only this smoke test needs frame_field_learning.
    from frame_field_learning import data_transforms
    config = {
        "data_dir_candidates": [
            "/data/titane/user/nigirard/data",
            "~/data",
            "/data"
        ],
        "dataset_params": {
            "small": True,
            "root_dirname": "mapping_challenge_dataset",
            "seed": 0,
            "train_fraction": 0.75
        },
        "num_workers": 8,
        "data_aug_params": {
            "enable": False,
            "vflip": True,
            "affine": True,
            "color_jitter": True,
            "device": "cuda"
        }
    }

    # Find data_dir
    data_dir = python_utils.choose_first_existing_path(config["data_dir_candidates"])
    if data_dir is None:
        print_utils.print_error("ERROR: Data directory not found!")
        # Report failure with a non-zero status (a bare exit() reports success).
        raise SystemExit(1)
    print_utils.print_info("Using data from {}".format(data_dir))
    root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])

    # --- Transforms: --- #
    # Pre-processing transform (done once then saved on disk): passed as pre_transform below.
    # Online transform done on the host (CPU):
    test_online_cpu_transform = data_transforms.get_eval_online_cpu_transform()
    # Online transform applied to whole batches after utils.batch_to_cuda:
    train_online_cuda_transform = data_transforms.get_online_cuda_transform(
        config, augmentations=config["data_aug_params"]["enable"])

    dataset = MappingChallenge(root_dir,
                               transform=test_online_cpu_transform,
                               pre_transform=data_transforms.get_offline_transform_patch(),
                               fold="train",
                               small=config["dataset_params"]["small"],
                               pool_size=config["num_workers"])

    print("# --- Sample 0 --- #")
    sample = dataset[0]
    print(sample.keys())
    for key, item in sample.items():
        print("{}: {}".format(key, type(item)))
    print(sample["image"].shape)
    print(len(sample["gt_polygons_image"]))

    print("# --- Samples --- #")
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=config["num_workers"])

    print("# --- Batches --- #")
    for batch in tqdm(data_loader):
        print("Images:")
        print(batch["image_relative_filepath"])
        print(batch["image"].shape)
        print(batch["gt_polygons_image"].shape)

        print("Apply online transform:")
        batch = utils.batch_to_cuda(batch)
        batch = train_online_cuda_transform(batch)
        batch = utils.batch_to_cpu(batch)
        print(batch["image"].shape)
        print(batch["gt_polygons_image"].shape)

        # Save output to visualize (first element of the batch, moved from channel-first to channel-last)
        seg = np.array(batch["gt_polygons_image"][0])
        seg = np.moveaxis(seg, 0, -1)
        seg_display = utils.get_seg_display(seg)
        seg_display = (seg_display * 255).astype(np.uint8)
        skimage.io.imsave("gt_seg.png", seg_display)
        skimage.io.imsave("gt_seg_edge.png", seg[:, :, 1])

        im = np.array(batch["image"][0])
        im = np.moveaxis(im, 0, -1)
        skimage.io.imsave('im.png', im)

        gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0])
        gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1)
        skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle)

        distances = np.array(batch["distances"][0])
        distances = np.moveaxis(distances, 0, -1)
        skimage.io.imsave('distances.png', distances)

        sizes = np.array(batch["sizes"][0])
        sizes = np.moveaxis(sizes, 0, -1)
        skimage.io.imsave('sizes.png', sizes)

        input("Press enter to continue...")
if __name__ == "__main__":
    # Executing this module as a script runs the main() smoke test; importing it does not.
    main()