import sys

import numpy as np
import scipy.ndimage
import torch
from tqdm import tqdm

from . import local_utils
from . import polygonize
from lydorn_utils import image_utils
from lydorn_utils import print_utils
from lydorn_utils import python_utils


def network_inference(config, model, batch):
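    """
    Runs a single forward pass of the model on batch, moving the batch to the
    GPU first when running on CUDA. Test-time augmentation is delegated to the
    model through the tta flag.
    """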
    if config["device"] == "cuda":
        batch = local_utils.batch_to_cuda(batch)
    pred, batch = model(batch, tta=config["eval_params"]["test_time_augmentation"])
    return pred, batch


def inference(config, model, tile_data, compute_polygonization=False, pool=None):
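    """
    Runs segmentation inference on tile_data, either patch-by-patch or on the
    whole image at once depending on eval_params.patch_size, then optionally
    polygonizes the predicted maps.
    """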
    if config["eval_params"]["patch_size"] is not None:
        # Cut the image into patches for inference
        inference_with_patching(config, model, tile_data)
        single_sample = True
    else:
        # Feed the image as-is to the model
        inference_no_patching(config, model, tile_data)
        single_sample = False

    # Polygonize:
    if compute_polygonization:
        pool = None if single_sample else pool  # A single big image is being processed
        crossfield = tile_data["crossfield"] if "crossfield" in tile_data else None
        polygons_batch, probs_batch = polygonize.polygonize(config["polygonize_params"], tile_data["seg"],
                                                            crossfield_batch=crossfield, pool=pool)
        tile_data["polygons"] = polygons_batch
        tile_data["polygon_probs"] = probs_batch

    return tile_data


def inference_no_patching(config, model, tile_data):
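    """
    Runs the model on the whole image in a single forward pass and stores the
    predicted segmentation (and crossfield, if computed) in tile_data.
    """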
    with torch.no_grad():
        batch = {
            "image": tile_data["image"],
            "image_mean": tile_data["image_mean"],
            "image_std": tile_data["image_std"]
        }
        try:
            pred, batch = network_inference(config, model, batch)
        except RuntimeError as e:
            print_utils.print_error("ERROR: " + str(e))
            if 1 < config["optim_params"]["eval_batch_size"]:
                print_utils.print_info("INFO: Try lowering the effective batch_size (which is {} currently). "
                                       "Note that in eval mode, the effective batch_size is equal to double the "
                                       "batch_size because gradients do not need to be computed, so double the "
                                       "memory is available. You can override the effective batch_size with the "
                                       "--eval_batch_size command-line argument."
                                       .format(config["optim_params"]["eval_batch_size"]))
            else:
                print_utils.print_info("INFO: The effective batch_size is 1 but the GPU still ran out of memory. "
                                       "You can specify parameters to split the image into patches for inference:\n"
                                       "--eval_patch_size is the size of the patch and should be chosen as big as "
                                       "memory allows.\n"
                                       "--eval_patch_overlap (optional, default=200) adds overlap between patches "
                                       "to avoid border artifacts.")
            sys.exit()

        tile_data["seg"] = pred["seg"]
        if "crossfield" in pred:
            tile_data["crossfield"] = pred["crossfield"]
    return tile_data


def inference_with_patching(config, model, tile_data):
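    """
    Runs the model on overlapping patches of a single big image and blends the
    per-patch predictions back into full-resolution output maps using
    distance-based pixel weights.
    """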
    assert len(tile_data["image"].shape) == 4 and tile_data["image"].shape[0] == 1, \
        f"When using inference with patching, tile_data should have a batch size of 1, " \
        f"with image's shape being (1, C, H, W), not {tile_data['image'].shape}"
    with torch.no_grad():
        # Init tile outputs (image is (N, C, H, W)):
        height = tile_data["image"].shape[2]
        width = tile_data["image"].shape[3]
        seg_channels = config["seg_params"]["compute_interior"] \
                       + config["seg_params"]["compute_edge"] \
                       + config["seg_params"]["compute_vertex"]
        if config["compute_seg"]:
            tile_data["seg"] = torch.zeros((1, seg_channels, height, width), device=config["device"])
        if config["compute_crossfield"]:
            tile_data["crossfield"] = torch.zeros((1, 4, height, width), device=config["device"])
        weight_map = torch.zeros((1, 1, height, width), device=config["device"])  # Accumulates the patch weights covering each pixel

        # Split tile into patches:
        stride = config["eval_params"]["patch_size"] - config["eval_params"]["patch_overlap"]
        patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
                                                                      stride=stride,
                                                                      patch_res=config["eval_params"]["patch_size"])

        # Compute patch pixel weights to merge overlapping patches back together smoothly:
        patch_weights = np.ones((config["eval_params"]["patch_size"] + 2, config["eval_params"]["patch_size"] + 2),
                                dtype=np.float64)
        patch_weights[0, :] = 0
        patch_weights[-1, :] = 0
        patch_weights[:, 0] = 0
        patch_weights[:, -1] = 0
        patch_weights = scipy.ndimage.distance_transform_edt(patch_weights)
        patch_weights = patch_weights[1:-1, 1:-1]
        patch_weights = torch.tensor(patch_weights, device=config["device"]).float()
        patch_weights = patch_weights[None, None, :, :]  # Add batch and channel dims
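        # The zeroed one-pixel border followed by the Euclidean distance transform gives
        # each pixel a weight equal to its distance to the patch border: 1 at the edges,
        # largest at the center. Dividing by weight_map below therefore blends overlapping
        # patches as a weighted average that favors each patch's center over its borders.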
        # Predict on each patch and save in outputs:
        for bbox in tqdm(patch_boundingboxes, desc="Running model on patches", leave=False):
            # Crop data
            batch = {
                "image": tile_data["image"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]],
                "image_mean": tile_data["image_mean"],
                "image_std": tile_data["image_std"],
            }
            # Send batch to device
            try:
                pred, batch = network_inference(config, model, batch)
            except RuntimeError as e:
                print_utils.print_error("ERROR: " + str(e))
                print_utils.print_info("INFO: Reduce --eval_patch_size until the patch fits in memory.")
                raise e
            if config["compute_seg"]:
                tile_data["seg"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights * pred["seg"]
            if config["compute_crossfield"]:
                tile_data["crossfield"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights * pred["crossfield"]
            weight_map[:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights

        # Take care of overlapping parts
        if config["compute_seg"]:
            tile_data["seg"] /= weight_map
        if config["compute_crossfield"]:
            tile_data["crossfield"] /= weight_map

    return tile_data


def load_checkpoint(model, checkpoints_dirpath, device):
    """
    Loads the most recent best-val checkpoint from checkpoints_dirpath,
    falling back to the last regular checkpoint if none exists.
    """
    filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
                                           endswith_str=".tar")
    if len(filepaths):
        filepaths = sorted(filepaths)
        filepath = filepaths[-1]  # Last best val checkpoint filepath in case there is more than one
        print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
    else:
        # No best val checkpoint found: fall back to the last checkpoint:
        filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
                                               startswith_str="checkpoint.")
        filepaths = sorted(filepaths)
        filepath = filepaths[-1]  # Last checkpoint
        print_utils.print_info("Loading last checkpoint: {}".format(filepath))

    device = torch.device(device)
    checkpoint = torch.load(filepath, map_location=device)  # map_location loads the weights onto the current device
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
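
# Minimal usage sketch (illustrative only; building `config`, `model` and
# `tile_data` is handled by the rest of the pipeline and is not shown here):
#
#   model = load_checkpoint(model, checkpoints_dirpath, config["device"])
#   model.eval()
#   tile_data = inference(config, model, tile_data, compute_polygonization=True)
#   polygons, probs = tile_data["polygons"], tile_data["polygon_probs"]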