r'''
Author: Egrt
Date: 2022-03-19 10:25:50
LastEditors: Egrt
LastEditTime: 2022-03-21 00:01:10
FilePath: \Luuu\gis.py
'''
# NOTE: the module docstring above is also used as the argparse description
# in GIS.get_args (``description=__doc__``). It is a raw string so the
# backslashes in the path are not treated as (invalid) escape sequences; its
# value is unchanged.
import os
import sys
import numpy as np
import skimage.io
import torch
from tqdm import tqdm
from frame_field_learning import data_transforms, save_utils
from frame_field_learning.model import FrameFieldModel
from frame_field_learning import inference
from frame_field_learning import local_utils
from backbone import get_backbone
from torch_lydorn import torchvision
import argparse
from lydorn_utils import print_utils
from lydorn_utils import run_utils


class GIS(object):
    """Single-image inference wrapper around a trained ``FrameFieldModel``.

    Construction parses command-line arguments, resolves the run config,
    builds the backbone and the model, and loads the run's checkpoint.
    ``detect_image`` then runs inference (with polygonization) on one image
    file and saves the configured outputs.
    """

    # Default instance attributes, applied before constructor kwargs.
    # NOTE(review): an earlier comment asked to edit ``model_path`` here, but
    # no ``model_path`` attribute is read anywhere in this class; the model is
    # located from --runs_dirpath / --run_name instead.
    _defaults = {
    }

    def __init__(self, **kwargs):
        """Parse CLI args, resolve the config and load the model.

        Every keyword argument is set as an instance attribute (after the
        entries of ``_defaults``).
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.args = self.get_args()
        self.config = self.launch_inference_from_filepath(self.args)
        self.generate()

    def get_args(self):
        """Build the argument parser and parse the process command line.

        Returns:
            The ``argparse.Namespace`` from ``parse_args()``, which reads
            ``sys.argv[1:]``.
        """
        # NOTE(review): parse_args() reads the host process's sys.argv and
        # exits on unknown arguments; keep that in mind when this class is
        # embedded in another application.
        argparser = argparse.ArgumentParser(description=__doc__)
        # NOTE(review): with nargs='*' this default is a plain str while
        # user-supplied values arrive as a list; not read by this class.
        argparser.add_argument(
            '--in_filepath',
            type=str,
            nargs='*',
            default='images/ex1images',
            help='For launching prediction on several images, use this argument to specify their paths.'
                 'If --out_dirpath is specified, prediction outputs will be saved there..'
                 'If --out_dirpath is not specified, predictions will be saved next to inputs.'
                 'Make sure to also specify the run_name of the model to use for prediction.')
        argparser.add_argument(
            '--out_dirpath',
            type=str,
            default='images',
            help='Path to the output directory of prediction when using the --in_filepath option to launch prediction on several images.')

        argparser.add_argument(
            '-c', '--config',
            type=str,
            help='Name of the config file, excluding the .json file extension.')

        argparser.add_argument(
            '--dataset_params',
            type=str,
            help='Allows to overwrite the dataset_params in the config file. Accepts a path to a .json file.')

        argparser.add_argument(
            '-r', '--runs_dirpath',
            default="runs",
            type=str,
            help='Directory where runs are recorded (model saves and logs).')
        argparser.add_argument(
            '--run_name',
            type=str,
            default='mapping_dataset.unet_resnet101_pretrained.train_val',
            help='Name of the run to use.'
                 'That name does not include the timestamp of the folder name: | .')
        argparser.add_argument(
            '--new_run',
            action='store_true',
            help="Train from scratch (when True) or train from the last checkpoint (when False)")
        argparser.add_argument(
            '--init_run_name',
            type=str,
            help="This is the run_name to initialize the weights from."
                 "If None, weights will be initialized randomly."
                 "This is a single word, without the timestamp.")
        argparser.add_argument(
            '--samples',
            type=int,
            help='Limits the number of samples to train (and validate and test) if set.')

        argparser.add_argument(
            '-b', '--batch_size',
            type=int,
            help='Batch size. Default value can be set in config file. Is doubled when no back propagation is done (while in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
        argparser.add_argument(
            '--eval_batch_size',
            type=int,
            help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
        argparser.add_argument(
            '-m', '--mode',
            default="train",
            type=str,
            choices=['train', 'eval', 'eval_coco'],
            help='Mode to launch the script in. '
                 'Train: train model on speciffied folds. '
                 'Eval: eval model on specified fold. '
                 'Eval_coco: measures COCO metrics of specified fold')
        argparser.add_argument(
            '--fold',
            nargs='*',
            type=str,
            choices=['train', 'val', 'test'],
            help='If training (mode=train): all folds entered here will be used for optimizing the network.'
                 'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch.'
                 'The most common scenario is to optimize on train and validate on val: select only train.'
                 'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both train and val folds.'
                 'Then for evaluation (mode=eval), we might want to evaluate on the val folds for hyper-parameter selection.'
                 'And finally evaluate (mode=eval) on the test fold for the final predictions (and possibly metric) for the paper/competition')
        argparser.add_argument(
            '--max_epoch',
            type=int,
            help='Stop training when max_epoch is reached. If not set, value in config is used.')
        argparser.add_argument(
            '--eval_patch_size',
            type=int,
            help='When evaluating, patch size the tile split into.')
        argparser.add_argument(
            '--eval_patch_overlap',
            type=int,
            help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
                 'the whole tile')

        argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of master node")
        argparser.add_argument('--master_port', default="6666", type=str, help="Port on master node")
        argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Number of total nodes")
        argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of gpus per node')
        argparser.add_argument('-nr', '--nr', default=0, type=int, help='Ranking within the nodes')
        args = argparser.parse_args()
        return args

    def launch_inference_from_filepath(self, args):
        """Resolve the run to use, load its config and apply CLI overrides.

        Also builds ``self.backbone`` from ``config["backbone_params"]``.

        Args:
            args: Namespace returned by ``get_args``.

        Returns:
            The config dict with command-line overrides and
            ``eval_params.run_dirpath`` filled in.

        Exits the process (``sys.exit(1)``) if no run name can be determined
        or the run's config cannot be loaded, as the error messages announce.
        """
        # --- First step: figure out what run (experiment) is to be evaluated
        # Option 1: the run_name argument is given in which case that's our run
        run_name = None
        config = None
        if args.run_name is not None:
            run_name = args.run_name
        # Else option 2: Check if a config has been given to look for the run_name
        if args.config is not None:
            config = run_utils.load_config(args.config)
            if config is not None and "run_name" in config and run_name is None:
                run_name = config["run_name"]
        # Else abort...
        if run_name is None:
            print_utils.print_error("ERROR: the run to evaluate could no be identified with the given arguments. "
                                    "Please specify either the --run_name argument or the --config argument "
                                    "linking to a config file that has a 'run_name' field filled with the name of "
                                    "the run name to evaluate.")
            # Without a run name nothing below can work; stop here instead of
            # failing later with an unrelated error.
            sys.exit(1)

        # --- Second step: get path to the run and if --config was not specified, load the config from the run's folder
        run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
        if config is None:
            config = run_utils.load_config(config_dirpath=run_dirpath)
        if config is None:
            print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
                                    f"Exiting now...")
            sys.exit(1)

        # --- Add command-line arguments
        if args.batch_size is not None:
            config["optim_params"]["batch_size"] = args.batch_size
        if args.eval_batch_size is not None:
            config["optim_params"]["eval_batch_size"] = args.eval_batch_size
        else:
            # No explicit eval batch size: use twice the training batch size.
            config["optim_params"]["eval_batch_size"] = 2 * config["optim_params"]["batch_size"]

        # --- Load params in config set as relative path to another JSON file
        config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")

        config["eval_params"]["run_dirpath"] = run_dirpath
        if args.eval_patch_size is not None:
            config["eval_params"]["patch_size"] = args.eval_patch_size
        if args.eval_patch_overlap is not None:
            config["eval_params"]["patch_overlap"] = args.eval_patch_overlap

        self.backbone = get_backbone(config["backbone_params"])
        return config

    def generate(self):
        """Build the model, move it to ``config["device"]``, load its checkpoint
        from the run's checkpoints sub-directory and switch it to eval mode."""
        # --- Online transform performed on the device (GPU):
        eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)

        print("Loading model...")
        self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
        self.model.to(self.config["device"])
        checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"],
                                                         self.config["optim_params"]["checkpoints_dirname"])
        self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
        self.model.eval()

    def get_save_filepath(self, base_filepath, name=None, ext=""):
        """Compose an output file path from a base path.

        Args:
            base_filepath: Either a ``(dirpath, basename)`` tuple, giving
                ``dirpath/[name/]basename + ext``, or a str path, giving
                ``base_filepath[.name] + ext``.
            name: Optional sub-directory (tuple form) or dotted suffix
                (str form).
            ext: Appended verbatim; include the leading dot.

        Returns:
            The composed path string.

        Raises:
            TypeError: if ``base_filepath`` is neither a tuple nor a str
                (previously this fell through to an UnboundLocalError).
        """
        if isinstance(base_filepath, tuple):
            dirpath, basename = base_filepath[0], base_filepath[1]
            if name is not None:
                return os.path.join(dirpath, name, basename + ext)
            return os.path.join(dirpath, basename + ext)
        if isinstance(base_filepath, str):
            if name is not None:
                return base_filepath + "." + name + ext
            return base_filepath + ext
        raise TypeError(f"base_filepath must be a tuple or a str, got {type(base_filepath).__name__}")

    def detect_image(self, in_filepath):
        """Run inference with polygonization on one image and save outputs.

        Args:
            in_filepath: Path of the image, read with ``skimage.io.imread``.

        Returns:
            ``(base_filename, paths)``. ``base_filename`` is the input file
            name without directory or extension. ``paths`` is
            ``[seg_mask_path, seg_path, pdf, cpg, dbf, shx, shp, prj]``: the
            first two are whatever ``save_utils.save_seg_mask`` /
            ``save_utils.save_seg`` return, or ``None`` when that output is
            disabled; the rest are built from fixed directory names and are
            not checked for existence.

        Raises:
            ValueError: if the image has fewer than 3 channels (a 2-D
                grayscale array counts as 1 channel).
        """
        out_dirpath = self.args.out_dirpath
        image = skimage.io.imread(in_filepath)

        # Validate channels before touching any shared config state. A 2-D
        # array has no channel axis; count it as one channel so it gets the
        # intended error rather than an IndexError.
        num_channels = image.shape[2] if image.ndim == 3 else 1
        if 3 < num_channels:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif num_channels < 3:
            message = f"Image {in_filepath} has only {num_channels} channels but the network expects 3 channels."
            print_utils.print_error(message)
            raise ValueError(message)

        # Per-channel mean/std of this image, computed on values scaled by 1/255.
        image_float = image / 255
        pixels = image_float.reshape(-1, image_float.shape[-1])
        mean = np.mean(pixels, axis=0)
        std = np.std(pixels, axis=0)
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        # Patch-wise inference is disabled (patch_size=None) when the image is
        # smaller than one patch in height or width. The configured value
        # (including a --eval_patch_size override) is restored even if
        # inference raises, so it never leaks into later calls.
        eval_params = self.config["eval_params"]
        configured_patch_size = eval_params["patch_size"]
        if configured_patch_size is not None and (
                image.shape[0] < configured_patch_size or image.shape[1] < configured_patch_size):
            eval_params["patch_size"] = None
        try:
            tile_data = inference.inference(self.config, self.model, sample, compute_polygonization=True)
        finally:
            eval_params["patch_size"] = configured_patch_size

        tile_data = local_utils.batch_to_cpu(tile_data)
        # Remove batch dim:
        tile_data = local_utils.split_batch(tile_data)[0]

        # Figuring out_base_filepath out:
        if out_dirpath is None:
            out_dirpath = os.path.dirname(in_filepath)
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        out_base_filepath = (out_dirpath, base_filename)

        save_outputs = eval_params["save_individual_outputs"]
        # Default to None so the return below works when an output is disabled.
        result_seg_mask_path = None
        result_seg_path = None
        if self.config["compute_seg"]:
            if save_outputs["seg_mask"]:
                # Binary mask from the first entry of tile_data["seg"] at 0.5
                # (presumably the first class channel — confirm).
                seg_mask = 0.5 < tile_data["seg"][0]
                result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
            if save_outputs["seg"]:
                result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])

        if save_outputs.get("poly_viz"):
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
        if save_outputs["poly_shapefile"]:
            save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])

        # NOTE(review): these directory names hard-code the polygonization
        # method ("acm") and tolerance (0.125); they must match what save_utils
        # writes for the active polygonization config — confirm.
        pdf_filepath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125', base_filename + ".pdf")
        shapefile_dirpath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125')
        cpg_filepath, dbf_filepath, shx_filepath, shp_filepath, prj_filepath = (
            os.path.join(shapefile_dirpath, base_filename + ext)
            for ext in (".cpg", ".dbf", ".shx", ".shp", ".prj")
        )
        return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, cpg_filepath,
                               dbf_filepath, shx_filepath, shp_filepath, prj_filepath]