File size: 11,546 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import os
import csv
from tqdm import tqdm
from multiprocess import Pool, Process, Queue
from functools import partial
import time
import torch
# from pytorch_memlab import profile, profile_every
from . import inference, save_utils, polygonize
from . import local_utils
from . import measures
from lydorn_utils import run_utils
from lydorn_utils import python_utils
from lydorn_utils import print_utils
from lydorn_utils import async_utils
class Evaluator:
def __init__(self, gpu: int, config: dict, shared_dict, barrier, model, run_dirpath):
# Set up the evaluator for one GPU process.
# gpu: GPU index; 0 gates directory creation and logging below, and load_checkpoint
#   uses it to build the "cuda:<gpu>" map_location.
# config: nested config dict; this method reads config["eval_params"]["batch_size_mult"]
#   and config["data_root_dir"].
# shared_dict, barrier: cross-process state and sync object used by evaluate();
#   presumably a multiprocessing manager dict and Barrier — TODO confirm against caller.
# model: network handed to the inference helpers in evaluate().
# run_dirpath: run directory; its last path component names the eval output directory.
self.gpu = gpu
self.config = config
assert 0 < self.config["eval_params"]["batch_size_mult"], \
"batch_size_mult in polygonize_params should be at least 1."
# NOTE(review): the message says "polygonize_params" but the checked key is under
# "eval_params"; also `assert` is stripped under `python -O`, so this check can vanish.
self.shared_dict = shared_dict
self.barrier = barrier
self.model = model
self.checkpoints_dirpath = run_utils.setup_run_subdir(run_dirpath,
# NOTE(review): the second argument and closing paren of this call are missing in
# this copy of the file, so it does not parse as-is — restore from version control.
self.eval_dirpath = os.path.join(config["data_root_dir"], "eval_runs", os.path.split(run_dirpath)[-1])
# Eval outputs go to <data_root_dir>/eval_runs/<last component of run_dirpath>.
if self.gpu == 0:
# Only GPU 0 creates the directory and logs; other processes do not create it here.
os.makedirs(self.eval_dirpath, exist_ok=True)
print_utils.print_info("Saving eval outputs to {}".format(self.eval_dirpath))
# @profile
def evaluate(self, split_name: str, ds:
# Run inference over `ds` for one split, optionally polygonize, and save outputs.
# split_name: used in the progress-bar text, output sub-paths and aggregated file names.
# ds: iterated batch by batch; len(tile_iterator) is also called, so it must support len().
# Per-sample outputs are routed to an async saver. Aggregated outputs (stats CSV,
# seg COCO, poly COCO) are gathered in self.shared_dict and written by GPU 0 only,
# after self.barrier.wait().
# NOTE(review): this copy of the method has lost lines (the end of the `ds` annotation
# and the closing paren on the def line, several `else:` branches, call continuations),
# so it does not parse as-is; the notes below mark each gap found.
# Prepare data saving:
flag_filepath_format = os.path.join(self.eval_dirpath, split_name, "{}.flag")
# Loading model
# NOTE(review): no checkpoint-load or model-mode call is visible under this comment.
# Create pool for multiprocessing
pool = None
if not self.config["eval_params"]["patch_size"]:
# If single image is not being split up, then a pool to process each sample in the batch makes sense
pool = Pool(processes=self.config["num_workers"])
# NOTE(review): no pool.close()/join() is visible in this method, so worker processes
# may outlive the call; consider try/finally or `with Pool(...) as pool`.
# Polygonize if any polygon output is requested.
# NOTE(review): the last operand of this `or` chain is missing in this copy (the third
# line ends with `or \`).
compute_polygonization = self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"] or \
self.config["eval_params"]["save_individual_outputs"]["poly_geojson"] or \
self.config["eval_params"]["save_individual_outputs"]["poly_viz"] or \
# Saving individual outputs to disk:
# True if any per-sample output flag equals True.
save_individual_outputs = True in self.config["eval_params"]["save_individual_outputs"].values()
saver_async = None
if save_individual_outputs:
save_outputs_partial = partial(save_utils.save_outputs, config=self.config, eval_dirpath=self.eval_dirpath,
split_name=split_name, flag_filepath_format=flag_filepath_format)
saver_async = async_utils.Async(save_outputs_partial)
# NOTE(review): no call that starts this saver, hands it samples, or waits for it is
# visible in this copy — confirm the async_utils.Async API and restore those calls.
# Saving aggregated outputs
# True if any aggregated output flag equals True.
save_aggregated_outputs = True in self.config["eval_params"]["save_aggregated_outputs"].values()
# Buffer of inferred batches, flushed once batch_size_mult of them are collected.
tile_data_list = []
if self.gpu == 0:
tile_iterator = tqdm(ds, desc="Eval {}: ".format(split_name), leave=True)
# NOTE(review): an `else:` appears to be missing here; as written, the next line always
# replaces the tqdm progress bar with the raw `ds`.
tile_iterator = ds
for tile_i, tile_data in enumerate(tile_iterator):
# --- Inference, add result to tile_data_list
if self.config["eval_params"]["patch_size"] is not None:
# Cut image into patches for inference
inference.inference_with_patching(self.config, self.model, tile_data)
# NOTE(review): an `else:` is missing before the next call. Neither branch appends
# tile_data to tile_data_list in this copy, so tile_data_list[0] below would raise IndexError.
# NOTE(review): patch_size is tested with `is not None` here but by truthiness when the
# pool is created above, so e.g. patch_size == 0 creates a pool AND uses patching.
# Feed images as-is to the model
inference.inference_no_patching(self.config, self.model, tile_data)
# --- Accumulate batches into tile_data_list until capacity is reached (or this is the last batch)
if self.config["eval_params"]["batch_size_mult"] <= len(tile_data_list)\
or tile_i == len(tile_iterator) - 1:
# Concat tensors of tile_data_list
# Merge the buffered batches key by key, using the first batch's value type for each key.
# List values are flattened into one list.
accumulated_tile_data = {}
for key in tile_data_list[0].keys():
if isinstance(tile_data_list[0][key], list):
accumulated_tile_data[key] = [item for _tile_data in tile_data_list for item in _tile_data[key]]
elif isinstance(tile_data_list[0][key], torch.Tensor):
# NOTE(review): the callee before `[` is missing in this copy (the trailing `, dim=0)`
# suggests torch.cat, i.e. concatenation along dim 0) — restore from version control.
accumulated_tile_data[key] =[_tile_data[key] for _tile_data in tile_data_list], dim=0)
# NOTE(review): an `else:` appears to be missing before this raise.
raise TypeError(f"Type {type(tile_data_list[0][key])} is not handled!")
tile_data_list = [] # Empty tile_data_list
# NOTE(review): an `else:` / `continue` branch appears to be missing around the next comment.
# tile_data_list is not full yet, continue running inference...
# --- Polygonize
if compute_polygonization:
# The crossfield is optional.
crossfield = accumulated_tile_data["crossfield"] if "crossfield" in accumulated_tile_data else None
accumulated_tile_data["polygons"], accumulated_tile_data["polygon_probs"] = polygonize.polygonize(
self.config["polygonize_params"], accumulated_tile_data["seg"],
# NOTE(review): the remaining arguments and closing paren are missing in this copy;
# `crossfield` (and probably `pool`) are prepared above but not passed in the visible code.
# --- Save output
# NOTE(review): the second operand of this condition is missing in this copy (the line
# ends with `or \`).
if self.config["eval_params"]["save_individual_outputs"]["seg_mask"] or \
# Take seg_interior:
# Mask is True where channel 0 of "seg" is strictly greater than seg_threshold.
seg_pred_mask = self.config["eval_params"]["seg_threshold"] < accumulated_tile_data["seg"][:, 0, ...]
accumulated_tile_data["seg_mask"] = seg_pred_mask
# Project helpers: presumably move tensors to CPU and split the batch into per-sample
# dicts — TODO confirm against local_utils.
accumulated_tile_data = local_utils.batch_to_cpu(accumulated_tile_data)
sample_list = local_utils.split_batch(accumulated_tile_data)
# Save individual outputs:
if save_individual_outputs:
for sample in sample_list:
# NOTE(review): this loop has no body in this copy; presumably each sample was handed
# to saver_async.
# Store aggregated outputs:
if save_aggregated_outputs:
if self.config["eval_params"]["save_aggregated_outputs"]["stats"]:
# IoU between channel 0 of the prediction and of the ground truth; each sample is
# flattened to one row (reshape(batch, -1)) before calling measures.iou.
y_pred = accumulated_tile_data["seg"][:, 0, ...].cpu()
if "gt_mask" in accumulated_tile_data:
y_true = accumulated_tile_data["gt_mask"][:, 0, ...]
elif "gt_polygons_image" in accumulated_tile_data:
y_true = accumulated_tile_data["gt_polygons_image"][:, 0, ...]
# NOTE(review): an `else:` appears to be missing before this raise.
raise ValueError("Either gt_mask or gt_polygons_image should be in accumulated_tile_data")
iou = measures.iou(y_pred.reshape(y_pred.shape[0], -1), y_true.reshape(y_true.shape[0], -1),
# NOTE(review): the rest of this call is missing in this copy. The visible code never
# stores the results in shared_dict["name_list"] / ["iou_list"], which the stats CSV
# below reads.
if self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
for sample in sample_list:
annotations = save_utils.seg_coco(sample)
# NOTE(review): `annotations` is computed but never stored in the visible code, although
# shared_dict["seg_coco_list"] is written out below.
if self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]:
for sample in sample_list:
annotations = save_utils.poly_coco(sample["polygons"], sample["polygon_probs"], sample["image_id"].item())
self.shared_dict["poly_coco_list"].append(annotations) # annotations could be a dict, or a list
# END of loop over samples
# Save aggregated results
if save_aggregated_outputs:
self.barrier.wait() # Wait on all processes so that shared_dict is synchronized.
# Only GPU 0 writes the aggregated files.
if self.gpu == 0:
if self.config["eval_params"]["save_aggregated_outputs"]["stats"]:
print("Start saving stats:")
# Save sample_stats in CSV:
t1 = time.time()
stats_filepath = os.path.join(self.eval_dirpath, "{}.stats.csv".format(split_name))
# NOTE(review): the file is opened without `with` and no close() is visible, so the
# handle leaks if anything below raises; the csv docs also recommend newline="".
stats_file = open(stats_filepath, "w")
fnames = ["name", "iou"]
writer = csv.DictWriter(stats_file, fieldnames=fnames)
# (name, iou) pairs are iterated in order of sample name.
for name, iou in sorted(zip(self.shared_dict["name_list"], self.shared_dict["iou_list"]), key=lambda pair: pair[0]):
# NOTE(review): the writer call that wraps the two dict items below (and any
# writeheader()) is missing in this copy.
"name": name,
"iou": iou
# NOTE(review): `:02` on a float only pads to width 2 and prints full precision (same
# in the prints below); use `:.2f` if two decimals were intended.
print(f"Finished in {time.time() - t1:02}s")
if self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
print("Start saving seg_coco:")
t1 = time.time()
seg_coco_filepath = os.path.join(self.eval_dirpath, "{}.annotation.seg.json".format(split_name))
python_utils.save_json(seg_coco_filepath, list(self.shared_dict["seg_coco_list"]))
print(f"Finished in {time.time() - t1:02}s")
if self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]:
print("Start saving poly_coco:")
poly_coco_base_filepath = os.path.join(self.eval_dirpath, f"{split_name}.annotation.poly")
t1 = time.time()
save_utils.save_poly_coco(self.shared_dict["poly_coco_list"], poly_coco_base_filepath)
print(f"Finished in {time.time() - t1:02}s")
# Sync point of individual outputs
if save_individual_outputs:
print_utils.print_info(f"GPU {self.gpu} -> INFO: Finishing saving individual outputs.")
# NOTE(review): no wait on saver_async is visible before this barrier, yet the barrier
# comment expects each process's saver to have finished.
self.barrier.wait() # Wait on all processes so that all saver_asyncs are finished
def load_checkpoint(self):
# Choose a checkpoint file in self.checkpoints_dirpath and torch.load it with
# map_location "cuda:<self.gpu>".
# Preference: the last entry of sorted() over files starting with "checkpoint.best_val.".
# Otherwise, the last entry of sorted() over files ending in ".tar"; the other filter
# arguments are missing in this copy.
# Raises FileNotFoundError if that fallback search finds nothing.
# NOTE(review): this copy has lost lines (the docstring quotes around the next line, two
# call continuations, an `else:`), so it does not parse as-is. What is done with
# `checkpoint` after the last line shown is not visible here.
Loads best val checkpoint in checkpoints_dirpath
filepaths = python_utils.get_filepaths(self.checkpoints_dirpath, startswith_str="checkpoint.best_val.",
# NOTE(review): remaining arguments and closing paren are missing in this copy.
if len(filepaths):
# NOTE(review): sorted() is lexicographic for string paths; if numbers in file names
# are not zero-padded, "last" may not be the newest checkpoint.
filepaths = sorted(filepaths)
filepath = filepaths[-1] # Last best val checkpoint filepath in case there is more than one
if self.gpu == 0:
print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
# NOTE(review): an `else:` appears to be missing before the fallback below.
# No best val checkpoint found: find last checkpoint:
filepaths = python_utils.get_filepaths(self.checkpoints_dirpath, endswith_str=".tar",
# NOTE(review): remaining arguments and closing paren are missing in this copy.
if len(filepaths) == 0:
raise FileNotFoundError("No checkpoint could be found at that location.")
filepaths = sorted(filepaths)
filepath = filepaths[-1] # Last checkpoint
if self.gpu == 0:
print_utils.print_info("Loading last checkpoint: {}".format(filepath))
# map_location is used to load on current device:
# torch.load unpickles the file, so only point it at trusted checkpoints.
checkpoint = torch.load(filepath, map_location="cuda:{}".format(self.gpu))