import sys

from tqdm import tqdm
import scipy.ndimage

import numpy as np
import torch

from . import local_utils
from . import polygonize

from lydorn_utils import image_utils
from lydorn_utils import print_utils
from lydorn_utils import python_utils


def network_inference(config, model, batch):
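    """
    Runs the model on one batch, moving the batch to the GPU first if the configured device is 'cuda'
    and applying test-time augmentation when enabled in eval_params.
    """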
    if config['device'] == 'cuda':
        batch = local_utils.batch_to_cuda(batch)
    pred, batch = model(batch, tta=config["eval_params"]["test_time_augmentation"])
    return pred, batch


def inference(config, model, tile_data, compute_polygonization=False, pool=None):
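    """
    Runs inference on tile_data, either on the whole image at once or patch by patch depending on
    eval_params.patch_size, and optionally polygonizes the predicted segmentation (using the predicted
    crossfield if available).
    """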
    if config["eval_params"]["patch_size"] is not None:
        # Cut image into patches for inference
        inference_with_patching(config, model, tile_data)
        single_sample = True
    else:
        # Feed images as-is to the model
        inference_no_patching(config, model, tile_data)
        single_sample = False

    # Polygonize:
    if compute_polygonization:
        pool = None if single_sample else pool  # Patched inference processes a single big image, so no multiprocessing pool is used
        crossfield = tile_data["crossfield"] if "crossfield" in tile_data else None
        polygons_batch, probs_batch = polygonize.polygonize(config["polygonize_params"], tile_data["seg"], crossfield_batch=crossfield,
                                         pool=pool)
        tile_data["polygons"] = polygons_batch
        tile_data["polygon_probs"] = probs_batch

    return tile_data


def inference_no_patching(config, model, tile_data):
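    """
    Feeds the whole image to the model in a single forward pass and stores the predicted seg
    (and crossfield, if computed) in tile_data.
    """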
    with torch.no_grad():
        batch = {
            "image": tile_data["image"],
            "image_mean": tile_data["image_mean"],
            "image_std": tile_data["image_std"]
        }
        try:
            pred, batch = network_inference(config, model, batch)
        except RuntimeError as e:
            print_utils.print_error("ERROR: " + str(e))
            if 1 < config["optim_params"]["eval_batch_size"]:
                print_utils.print_info("INFO: Try lowering the effective batch_size (which is {} currently). "
                                       "Note that in eval mode, the effective bath_size is equal to double the batch_size "
                                       "because gradients do not need to "
                                       "be computed so double the memory is available. "
                                       "You can override the effective batch_size with the --eval_batch_size command-line argument."
                                       .format(config["optim_params"]["eval_batch_size"]))
            else:
                print_utils.print_info("INFO: The effective batch_size is 1 but the GPU still ran out of memory."
                                       "You can specify parameters to split the image into patches for inference:\n"
                                       "--eval_patch_size is the size of the patch and should be chosen as big as memory allows.\n"
                                       "--eval_patch_overlap (optional, default=200) adds overlaps between patches to avoid border artifacts."
                                       .format(config["optim_params"]["eval_batch_size"]))
            sys.exit()

        tile_data["seg"] = pred["seg"]
        if "crossfield" in pred:
            tile_data["crossfield"] = pred["crossfield"]

    return tile_data


def inference_with_patching(config, model, tile_data):
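    """
    Splits the single input image into overlapping patches, runs the model on each patch and blends the
    patch predictions back into full-resolution seg (and crossfield) outputs with distance-based weights.
    """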
    assert len(tile_data["image"].shape) == 4 and tile_data["image"].shape[0] == 1, \
        f"When using inference with patching, tile_data should have a batch size of 1, " \
        f"with image's shape being (1, C, H, W), not {tile_data['image'].shape}"
    with torch.no_grad():
        # Init tile outputs (image is (N, C, H, W)):
        height = tile_data["image"].shape[2]
        width = tile_data["image"].shape[3]
        seg_channels = config["seg_params"]["compute_interior"] \
                       + config["seg_params"]["compute_edge"] \
                       + config["seg_params"]["compute_vertex"]
        if config["compute_seg"]:
            tile_data["seg"] = torch.zeros((1, seg_channels, height, width), device=config["device"])
        if config["compute_crossfield"]:
            tile_data["crossfield"] = torch.zeros((1, 4, height, width), device=config["device"])
        weight_map = torch.zeros((1, 1, height, width), device=config["device"])  # Count number of patches on top of each pixel

        # Split tile in patches:
        stride = config["eval_params"]["patch_size"] - config["eval_params"]["patch_overlap"]
        patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
                                                                      stride=stride,
                                                                      patch_res=config["eval_params"]["patch_size"])
        # Compute patch pixel weights to merge overlapping patches back together smoothly:
        patch_weights = np.ones((config["eval_params"]["patch_size"] + 2, config["eval_params"]["patch_size"] + 2),
                                dtype=np.float64)
        patch_weights[0, :] = 0
        patch_weights[-1, :] = 0
        patch_weights[:, 0] = 0
        patch_weights[:, -1] = 0
        patch_weights = scipy.ndimage.distance_transform_edt(patch_weights)
        patch_weights = patch_weights[1:-1, 1:-1]
        patch_weights = torch.tensor(patch_weights, device=config["device"]).float()
        patch_weights = patch_weights[None, None, :, :]  # Adding batch and channels dims
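        # Each patch's border pixels get a weight of ~0 and its center the largest weight, so the division
        # by weight_map further down turns the accumulated patch predictions into a smooth weighted average.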

        # Predict on each patch and save in outputs:
        for bbox in tqdm(patch_boundingboxes, desc="Running model on patches", leave=False):
            # Crop data
            batch = {
                "image": tile_data["image"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]],
                "image_mean": tile_data["image_mean"],
                "image_std": tile_data["image_std"],
            }
            # Send batch to device
            try:
                pred, batch = network_inference(config, model, batch)
            except RuntimeError as e:
                print_utils.print_error("ERROR: " + str(e))
                print_utils.print_info("INFO: Reduce --eval_patch_size until the patch fits in memory.")
                raise e

            if config["compute_seg"]:
                tile_data["seg"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights * pred["seg"]
            if config["compute_crossfield"]:
                tile_data["crossfield"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights * pred["crossfield"]
            weight_map[:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights

        # Take care of overlapping parts
        if config["compute_seg"]:
            tile_data["seg"] /= weight_map
        if config["compute_crossfield"]:
            tile_data["crossfield"] /= weight_map

    return tile_data


def load_checkpoint(model, checkpoints_dirpath, device):
    """
    Loads best val checkpoint in checkpoints_dirpath
    """
    filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
                                           endswith_str=".tar")
    if len(filepaths):
        filepaths = sorted(filepaths)
        filepath = filepaths[-1]  # Last best val checkpoint filepath in case there is more than one
        print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
    else:
        # No best val checkpoint found: fall back to the last checkpoint:
        filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
                                               startswith_str="checkpoint.")
        filepaths = sorted(filepaths)
        filepath = filepaths[-1]  # Last checkpoint
        print_utils.print_info("Loading last checkpoint: {}".format(filepath))

    device = torch.device(device)
    checkpoint = torch.load(filepath, map_location=device)  # map_location is used to load on current device

    model.load_state_dict(checkpoint['model_state_dict'])

    return model
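

# Example usage (a minimal sketch, not part of the module's API): how the model and config are built depends
# on the rest of the project, so the placeholders below (build_model, checkpoints_dirpath, the tensors fed
# into tile_data) are assumptions for illustration only.
#
#   model = build_model(config)                      # however the training/eval scripts construct the model
#   model.to(config["device"])
#   model = load_checkpoint(model, checkpoints_dirpath, config["device"])
#   model.eval()
#
#   tile_data = {
#       "image": image_tensor,        # (N, C, H, W) float tensor
#       "image_mean": image_mean,     # per-channel mean used for normalization
#       "image_std": image_std,       # per-channel std used for normalization
#   }
#   tile_data = inference(config, model, tile_data, compute_polygonization=True)
#   seg, polygons = tile_data["seg"], tile_data["polygons"]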