File size: 13,980 Bytes
250d697
 
 
 
ff63123
250d697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
'''
Author: Egrt
Date: 2022-03-19 10:25:50
LastEditors: Egrt
LastEditTime: 2022-03-20 14:58:13
FilePath: \Luuu\gis.py
'''
import os
import numpy as np
import skimage.io
import torch

from tqdm import tqdm

from frame_field_learning import data_transforms, save_utils
from frame_field_learning.model import FrameFieldModel
from frame_field_learning import inference
from frame_field_learning import local_utils
from backbone import get_backbone
from torch_lydorn import torchvision
import argparse
from lydorn_utils import print_utils
from lydorn_utils import run_utils


class GIS(object):
    """Single-image building polygonization wrapper around a FrameFieldModel.

    On construction it parses command-line arguments, resolves the run and
    config to use, builds the backbone and the ``FrameFieldModel``, and loads
    the model weights from the run's checkpoints directory. ``detect_image``
    then runs inference on one image file and writes the outputs enabled in
    ``config["eval_params"]["save_individual_outputs"]``.
    """
    #-----------------------------------------#
    #   Class-level defaults: each entry is copied onto the instance and can
    #   be overridden by a keyword argument of the same name given to
    #   __init__. (The original note said to adjust `model_path` here, but no
    #   such default is currently defined.)
    #-----------------------------------------#
    _defaults = {

    }

    #---------------------------------------------------#
    #   Initialize GIS
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults/kwargs, parse CLI arguments, load the config and the model.

        Note: argument parsing reads ``sys.argv``, so the hosting process's
        command line must only contain options understood by ``get_args``.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.args = self.get_args()
        self.config = self.launch_inference_from_filepath(self.args)
        self.generate()

    def get_args(self, argv=None):
        """Build the argument parser and parse the command line.

        Args:
            argv: list of argument strings to parse. ``None`` (the default)
                parses ``sys.argv[1:]``, which is the previous behavior.

        Returns:
            argparse.Namespace with the parsed options.
        """
        argparser = argparse.ArgumentParser(description=__doc__)
        # NOTE(review): the default is a single string, while a value given on
        # the command line becomes a list (nargs='*'); consumers must handle both.
        argparser.add_argument(
            '--in_filepath',
            type=str,
            nargs='*',
            default='images/ex1images',
            help='For launching prediction on several images, use this argument to specify their paths. '
                'If --out_dirpath is specified, prediction outputs will be saved there. '
                'If --out_dirpath is not specified, predictions will be saved next to inputs. '
                'Make sure to also specify the run_name of the model to use for prediction.')
        argparser.add_argument(
            '--out_dirpath',
            type=str,
            default='images',
            help='Path to the output directory of prediction when using the --in_filepath option to launch prediction on several images.')

        argparser.add_argument(
            '-c', '--config',
            type=str,
            help='Name of the config file, excluding the .json file extension.')
        argparser.add_argument(
            '--dataset_params',
            type=str,
            help='Allows to overwrite the dataset_params in the config file. Accepts a path to a .json file.')

        argparser.add_argument(
            '-r', '--runs_dirpath',
            default="runs",
            type=str,
            help='Directory where runs are recorded (model saves and logs).')
        argparser.add_argument(
            '--run_name',
            type=str,
            default='mapping_dataset.unet_resnet101_pretrained.train_val',
            help='Name of the run to use. '
                'That name does not include the timestamp of the folder name: <run_name> | <yyyy-mm-dd hh:mm:ss>.')
        argparser.add_argument(
            '--new_run',
            action='store_true',
            help="Train from scratch (when True) or train from the last checkpoint (when False)")
        argparser.add_argument(
            '--init_run_name',
            type=str,
            help="This is the run_name to initialize the weights from. "
                "If None, weights will be initialized randomly. "
                "This is a single word, without the timestamp.")
        argparser.add_argument(
            '--samples',
            type=int,
            help='Limits the number of samples to train (and validate and test) if set.')

        argparser.add_argument(
            '-b', '--batch_size',
            type=int,
            help='Batch size. Default value can be set in config file. Is doubled when no back propagation is done (while in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
        argparser.add_argument(
            '--eval_batch_size',
            type=int,
            help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
        argparser.add_argument(
            '-m', '--mode',
            default="train",
            type=str,
            choices=['train', 'eval', 'eval_coco'],
            help='Mode to launch the script in. '
                'Train: train model on specified folds. '
                'Eval: eval model on specified fold. '
                'Eval_coco: measures COCO metrics of specified fold')
        argparser.add_argument(
            '--fold',
            nargs='*',
            type=str,
            choices=['train', 'val', 'test'],
            help='If training (mode=train): all folds entered here will be used for optimizing the network. '
                'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch. '
                'The most common scenario is to optimize on train and validate on val: select only train. '
                'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both train and val folds. '
                'Then for evaluation (mode=eval), we might want to evaluate on the val folds for hyper-parameter selection. '
                'And finally evaluate (mode=eval) on the test fold for the final predictions (and possibly metric) for the paper/competition')
        argparser.add_argument(
            '--max_epoch',
            type=int,
            help='Stop training when max_epoch is reached. If not set, value in config is used.')
        argparser.add_argument(
            '--eval_patch_size',
            type=int,
            help='When evaluating, patch size the tile split into.')
        argparser.add_argument(
            '--eval_patch_overlap',
            type=int,
            help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
                'the whole tile')

        argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of master node")
        argparser.add_argument('--master_port', default="6666", type=str, help="Port on master node")
        argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Number of total nodes")
        argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of gpus per node')
        argparser.add_argument('-nr', '--nr', default=0, type=int, help='Ranking within the nodes')

        return argparser.parse_args(argv)

    def launch_inference_from_filepath(self, args):
        """Resolve the run directory and config, then apply command-line overrides.

        Side effect: sets ``self.backbone`` from ``config["backbone_params"]``.

        Args:
            args: namespace returned by ``get_args``.

        Returns:
            The config dict, with ``eval_params.run_dirpath`` set to the run's
            directory and CLI batch/patch overrides applied.

        Raises:
            SystemExit: if neither --run_name nor a config with a "run_name"
                field identifies the run, or if the run's config cannot be loaded.
        """
        # --- First step: figure out what run (experiment) is to be evaluated
        # Option 1: the run_name argument is given, in which case that's our run.
        run_name = args.run_name
        config = None
        # Option 2: load the config given by --config; it may also supply run_name.
        if args.config is not None:
            config = run_utils.load_config(args.config)
            if config is not None and "run_name" in config and run_name is None:
                run_name = config["run_name"]
        # Otherwise abort: without a run name there is no run directory to use.
        if run_name is None:
            print_utils.print_error("ERROR: the run to evaluate could not be identified with the given arguments. "
                                    "Please specify either the --run_name argument or the --config argument "
                                    "linking to a config file that has a 'run_name' field filled with the name of "
                                    "the run name to evaluate.")
            raise SystemExit(1)

        # --- Second step: get path to the run and if --config was not specified, load the config from the run's folder
        run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
        if config is None:
            config = run_utils.load_config(config_dirpath=run_dirpath)
        if config is None:
            print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
                                    f"Exiting now...")
            raise SystemExit(1)

        # --- Add command-line arguments
        optim_params = config["optim_params"]
        if args.batch_size is not None:
            optim_params["batch_size"] = args.batch_size
        if args.eval_batch_size is not None:
            optim_params["eval_batch_size"] = args.eval_batch_size
        else:
            # No backpropagation during evaluation, so a larger batch is used by default.
            optim_params["eval_batch_size"] = 2 * optim_params["batch_size"]

        # --- Load params in config set as relative path to another JSON file
        config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")

        config["eval_params"]["run_dirpath"] = run_dirpath
        if args.eval_patch_size is not None:
            config["eval_params"]["patch_size"] = args.eval_patch_size
        if args.eval_patch_overlap is not None:
            config["eval_params"]["patch_overlap"] = args.eval_patch_overlap

        self.backbone = get_backbone(config["backbone_params"])
        return config

    # Load the model
    def generate(self):
        """Build the FrameFieldModel on ``config["device"]`` and load its checkpoint weights.

        Uses ``self.config`` and ``self.backbone`` (set by
        ``launch_inference_from_filepath``) and leaves ``self.model`` in eval mode.
        """
        # --- Online transform performed on the device (GPU):
        eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)

        print("Loading model...")
        self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
        self.model.to(self.config["device"])
        checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"], self.config["optim_params"]["checkpoints_dirname"])
        self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
        self.model.eval()

    def get_save_filepath(self, base_filepath, name=None, ext=""):
        """Build an output file path from a base path, an optional name and an extension.

        Args:
            base_filepath: either a ``(dirpath, basename)`` tuple, giving
                ``dirpath/[name/]basename + ext``, or a str, giving
                ``base_filepath[.name] + ext``.
            name: optional sub-directory (tuple form) or dotted suffix (str form).
            ext: extension appended verbatim, including its leading dot.

        Raises:
            TypeError: if ``base_filepath`` is neither a tuple nor a str.
        """
        if isinstance(base_filepath, tuple):
            dirpath, basename = base_filepath[0], base_filepath[1]
            if name is not None:
                return os.path.join(dirpath, name, basename + ext)
            return os.path.join(dirpath, basename + ext)
        if isinstance(base_filepath, str):
            if name is not None:
                return base_filepath + "." + name + ext
            return base_filepath + ext
        raise TypeError(f"base_filepath must be a tuple or a str, got {type(base_filepath).__name__}")

    def _read_rgb_image(self, in_filepath):
        """Read an image file and return its first 3 channels.

        Raises:
            ValueError: if the image has fewer than 3 channels; a 2-D
                (grayscale) array counts as 1 channel.
        """
        image = skimage.io.imread(in_filepath)
        n_channels = image.shape[2] if image.ndim == 3 else 1
        if 3 < n_channels:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif n_channels < 3:
            message = f"Image {in_filepath} has only {n_channels} channels but the network expects 3 channels."
            print_utils.print_error(message)
            raise ValueError(message)
        return image

    def _config_for_image(self, image):
        """Return the config to run inference with for ``image``; never mutates ``self.config``.

        If the image is smaller than the configured patch size in either of its
        first two dimensions, patching is disabled for this image only. A
        shallow copy of the config with a copied ``eval_params`` is returned.
        """
        eval_params = self.config["eval_params"]
        patch_size = eval_params.get("patch_size")
        if patch_size is not None and (image.shape[0] < patch_size or image.shape[1] < patch_size):
            return {**self.config, "eval_params": {**eval_params, "patch_size": None}}
        return self.config

    # Detect a single image
    def detect_image(self, in_filepath):
        """Run inference on one image file and save the enabled outputs.

        Args:
            in_filepath: path of the image to process.

        Returns:
            ``(base_filename, paths)``: ``base_filename`` is the input file
            name without extension, and ``paths`` is
            ``[seg_mask_path, seg_path, pdf_filepath, cpg_filepath,
            dbf_filepath, shx_filepath, shp_filepath, prj_filepath]``.
            The first two entries are whatever ``save_utils.save_seg_mask`` /
            ``save_utils.save_seg`` returned, or ``None`` when that output is
            disabled. The remaining paths are built from fixed sub-directory
            names and are returned whether or not those files were written.

        Raises:
            ValueError: if the image has fewer than 3 channels.
        """
        out_dirpath = self.args.out_dirpath
        image = self._read_rgb_image(in_filepath)
        # Per-image config: small images are processed without patching,
        # without affecting later calls.
        config = self._config_for_image(image)

        # Per-channel mean/std after dividing by 255 (presumes 8-bit input — TODO confirm).
        image_float = image / 255
        pixels = image_float.reshape(-1, image_float.shape[-1])
        mean = np.mean(pixels, axis=0)
        std = np.std(pixels, axis=0)
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        tile_data = inference.inference(config, self.model, sample, compute_polygonization=True)
        tile_data = local_utils.batch_to_cpu(tile_data)
        # Remove batch dim:
        tile_data = local_utils.split_batch(tile_data)[0]

        # Figure out the output base path: (directory, file name without extension).
        if out_dirpath is None:
            out_dirpath = os.path.dirname(in_filepath)
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        out_base_filepath = (out_dirpath, base_filename)

        save_outputs = config["eval_params"]["save_individual_outputs"]
        # Stay None when the corresponding output is disabled.
        result_seg_mask_path = None
        result_seg_path = None
        if config["compute_seg"]:
            if save_outputs["seg_mask"]:
                # Threshold channel 0 of the segmentation at 0.5.
                seg_mask = 0.5 < tile_data["seg"][0]
                result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
            if save_outputs["seg"]:
                result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
        if save_outputs.get("poly_viz"):
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
        if save_outputs["poly_shapefile"]:
            save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])

        # NOTE(review): these sub-directory names are hard-coded; presumably they
        # must match the polygonization method/tolerance save_utils uses — TODO
        # derive them from the config.
        poly_viz_dirpath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125')
        shapefile_dirpath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125')
        pdf_filepath = os.path.join(poly_viz_dirpath, base_filename + ".pdf")
        shapefile_filepaths = [
            os.path.join(shapefile_dirpath, base_filename + ext)
            for ext in (".cpg", ".dbf", ".shx", ".shp", ".prj")
        ]

        return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, *shapefile_filepaths]