白鹭先生 committed on
Commit 250d697 · 1 Parent(s): 41a8223
Files changed (1)
  1. gis.py +280 -0
gis.py ADDED
@@ -0,0 +1,280 @@
+ '''
+ Author: Egrt
+ Date: 2022-03-19 10:25:50
+ LastEditors: Egrt
+ LastEditTime: 2022-03-20 13:38:21
+ FilePath: \Luuu\gis.py
+ '''
+ import os
+ import numpy as np
+ import skimage.io
+ import torch
+
+ from tqdm import tqdm
+
+ from frame_field_learning import data_transforms, save_utils
+ from frame_field_learning.model import FrameFieldModel
+ from frame_field_learning import inference
+ from frame_field_learning import local_utils
+ from backbone import get_backbone
+ from torch_lydorn import torchvision
+ import argparse
+ from lydorn_utils import print_utils
+ from lydorn_utils import run_utils
+
+
+ class GIS(object):
+     #-----------------------------------------#
+     #   Note: update model_path as needed
+     #-----------------------------------------#
+     _defaults = {
+
+     }
+
+     #---------------------------------------------------#
+     #   Initialize the model
+     #---------------------------------------------------#
+     def __init__(self, **kwargs):
+         self.__dict__.update(self._defaults)
+         for name, value in kwargs.items():
+             setattr(self, name, value)
+         self.args = self.get_args()
+         self.config = self.launch_inference_from_filepath(self.args)
+         self.generate()
+
+     def get_args(self):
+         argparser = argparse.ArgumentParser(description=__doc__)
+         argparser.add_argument(
+             '--in_filepath',
+             type=str,
+             nargs='*',
+             default='images/ex1images',
+             help='For launching prediction on several images, use this argument to specify their paths. '
+                  'If --out_dirpath is specified, prediction outputs will be saved there. '
+                  'If --out_dirpath is not specified, predictions will be saved next to inputs. '
+                  'Make sure to also specify the run_name of the model to use for prediction.')
+         argparser.add_argument(
+             '--out_dirpath',
+             type=str,
+             default='images',
+             help='Path to the output directory of predictions when using the --in_filepath option to launch prediction on several images.')
+
+         argparser.add_argument(
+             '-c', '--config',
+             type=str,
+             help='Name of the config file, excluding the .json file extension.')
+         argparser.add_argument(
+             '--dataset_params',
+             type=str,
+             help='Allows overwriting the dataset_params in the config file. Accepts a path to a .json file.')
+
+         argparser.add_argument(
+             '-r', '--runs_dirpath',
+             default="runs",
+             type=str,
+             help='Directory where runs are recorded (model saves and logs).')
+         argparser.add_argument(
+             '--run_name',
+             type=str,
+             default='mapping_dataset.unet_resnet101_pretrained.train_val',
+             help='Name of the run to use. '
+                  'That name does not include the timestamp of the folder name: <run_name> | <yyyy-mm-dd hh:mm:ss>.')
+         argparser.add_argument(
+             '--new_run',
+             action='store_true',
+             help="Train from scratch (when True) or train from the last checkpoint (when False).")
+         argparser.add_argument(
+             '--init_run_name',
+             type=str,
+             help="This is the run_name to initialize the weights from. "
+                  "If None, weights will be initialized randomly. "
+                  "This is a single word, without the timestamp.")
+         argparser.add_argument(
+             '--samples',
+             type=int,
+             help='Limits the number of samples to train (and validate and test) on if set.')
+
+         argparser.add_argument(
+             '-b', '--batch_size',
+             type=int,
+             help='Batch size. The default value can be set in the config file. It is doubled when no backpropagation is done (in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
+         argparser.add_argument(
+             '--eval_batch_size',
+             type=int,
+             help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
+         argparser.add_argument(
+             '-m', '--mode',
+             default="train",
+             type=str,
+             choices=['train', 'eval', 'eval_coco'],
+             help='Mode to launch the script in. '
+                  'Train: train the model on the specified folds. '
+                  'Eval: evaluate the model on the specified fold. '
+                  'Eval_coco: measure COCO metrics on the specified fold.')
+         argparser.add_argument(
+             '--fold',
+             nargs='*',
+             type=str,
+             choices=['train', 'val', 'test'],
+             help='If training (mode=train): all folds entered here will be used for optimizing the network. '
+                  'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch. '
+                  'The most common scenario is to optimize on train and validate on val: select only train. '
+                  'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both the train and val folds. '
+                  'Then for evaluation (mode=eval), we might want to evaluate on the val fold for hyper-parameter selection. '
+                  'Finally, evaluate (mode=eval) on the test fold for the final predictions (and possibly metrics) for the paper/competition.')
+         argparser.add_argument(
+             '--max_epoch',
+             type=int,
+             help='Stop training when max_epoch is reached. If not set, the value in the config is used.')
+         argparser.add_argument(
+             '--eval_patch_size',
+             type=int,
+             help='When evaluating, the patch size the tile is split into.')
+         argparser.add_argument(
+             '--eval_patch_overlap',
+             type=int,
+             help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
+                  'the whole tile.')
+
+         argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of the master node")
+         argparser.add_argument('--master_port', default="6666", type=str, help="Port on the master node")
+         argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Total number of nodes")
+         argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of GPUs per node')
+         argparser.add_argument('-nr', '--nr', default=0, type=int, help='Ranking within the nodes')
+
+         args = argparser.parse_args()
+
+         return args
+
+     def launch_inference_from_filepath(self, args):
+
+         # --- First step: figure out which run (experiment) is to be evaluated
+         # Option 1: the run_name argument is given, in which case that is our run
+         run_name = None
+         config = None
+         if args.run_name is not None:
+             run_name = args.run_name
+         # Else option 2: check if a config has been given to look for the run_name
+         if args.config is not None:
+             config = run_utils.load_config(args.config)
+             if config is not None and "run_name" in config and run_name is None:
+                 run_name = config["run_name"]
+         # Else abort...
+         if run_name is None:
+             print_utils.print_error("ERROR: the run to evaluate could not be identified with the given arguments. "
+                                     "Please specify either the --run_name argument or the --config argument "
+                                     "linking to a config file that has a 'run_name' field filled with the name of "
+                                     "the run to evaluate.")
+
+         # --- Second step: get the path to the run and, if --config was not specified, load the config from the run's folder
+         run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
+         if config is None:
+             config = run_utils.load_config(config_dirpath=run_dirpath)
+         if config is None:
+             print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
+                                     f"Exiting now...")
+
+         # --- Add command-line arguments
+         if args.batch_size is not None:
+             config["optim_params"]["batch_size"] = args.batch_size
+         if args.eval_batch_size is not None:
+             config["optim_params"]["eval_batch_size"] = args.eval_batch_size
+         else:
+             config["optim_params"]["eval_batch_size"] = 2 * config["optim_params"]["batch_size"]
+
+         # --- Load params in config set as relative paths to other JSON files
+         config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")
+
+         config["eval_params"]["run_dirpath"] = run_dirpath
+         if args.eval_patch_size is not None:
+             config["eval_params"]["patch_size"] = args.eval_patch_size
+         if args.eval_patch_overlap is not None:
+             config["eval_params"]["patch_overlap"] = args.eval_patch_overlap
+
+         self.backbone = get_backbone(config["backbone_params"])
+         return config
+
+     # Load the model
+     def generate(self):
+         # --- Online transform performed on the device (GPU):
+         eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)
+
+         print("Loading model...")
+         self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
+         self.model.to(self.config["device"])
+         checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"], self.config["optim_params"]["checkpoints_dirname"])
+         self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
+         self.model.eval()
+
+     def get_save_filepath(self, base_filepath, name=None, ext=""):
+         if type(base_filepath) is tuple:
+             if name is not None:
+                 save_filepath = os.path.join(base_filepath[0], name, base_filepath[1] + ext)
+             else:
+                 save_filepath = os.path.join(base_filepath[0], base_filepath[1] + ext)
+         elif type(base_filepath) is str:
+             if name is not None:
+                 save_filepath = base_filepath + "." + name + ext
+             else:
+                 save_filepath = base_filepath + ext
+         else:
+             raise TypeError(f"base_filepath should be a tuple or a str, not {type(base_filepath)}")
+         return save_filepath
+
+     # Run detection on a single image
+     def detect_image(self, in_filepath):
+         out_dirpath = self.args.out_dirpath
+         image = skimage.io.imread(in_filepath)
+         patch_size = self.config['eval_params']['patch_size']
+         # If the image is smaller than the expected patch size, disable patch-wise processing
+         if image.shape[0] < patch_size or image.shape[1] < patch_size:
+             self.config['eval_params']['patch_size'] = None
+         if 3 < image.shape[2]:
+             print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
+             image = image[:, :, :3]
+         elif image.shape[2] < 3:
+             print_utils.print_error(f"Image {in_filepath} has only {image.shape[2]} channels but the network expects 3 channels.")
+             raise ValueError
+         image_float = image / 255
+         mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
+         std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)
+         sample = {
+             "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
+             "image_mean": torch.from_numpy(mean)[None, ...],
+             "image_std": torch.from_numpy(std)[None, ...],
+             "image_filepath": [in_filepath],
+         }
+
+         tile_data = inference.inference(self.config, self.model, sample, compute_polygonization=True)
+
+         tile_data = local_utils.batch_to_cpu(tile_data)
+
+         # Remove batch dim:
+         tile_data = local_utils.split_batch(tile_data)[0]
+
+         # Figure out out_base_filepath:
+         if out_dirpath is None:
+             out_dirpath = os.path.dirname(in_filepath)
+         base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
+         out_base_filepath = (out_dirpath, base_filename)
+
+         if self.config["compute_seg"]:
+             if self.config["eval_params"]["save_individual_outputs"]["seg_mask"]:
+                 seg_mask = 0.5 < tile_data["seg"][0]
+                 result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
+             if self.config["eval_params"]["save_individual_outputs"]["seg"]:
+                 result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
+         if "poly_viz" in self.config["eval_params"]["save_individual_outputs"] and \
+                 self.config["eval_params"]["save_individual_outputs"]["poly_viz"]:
+             save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
+         if self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"]:
+             save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])
+         pdf_filepath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125', base_filename + ".pdf")
+         cpg_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".cpg")
+         dbf_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".dbf")
+         shx_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shx")
+         shp_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shp")
+         prj_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".prj")
+
+         return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, cpg_filepath, dbf_filepath, shx_filepath, shp_filepath, prj_filepath]
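Minimal usage sketch (not part of the commit): it assumes the run's config and checkpoint for the default --run_name are already present under runs/, and the image path is hypothetical.

    # Instantiate the wrapper: parses CLI arguments, loads the config and checkpoint.
    gis = GIS()
    # Run detection on a single image; returns the base filename and the list of output filepaths.
    base_filename, output_paths = gis.detect_image("images/ex1images/example.png")
    print(base_filename)
    for path in output_paths:
        print(path)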