|
'''
Single-image polygonization inference: loads a trained FrameFieldModel run and
saves segmentation and polygon outputs for the given input images.

Author: Egrt
Date: 2022-03-19 10:25:50
LastEditors: Egrt
LastEditTime: 2022-03-20 14:58:13
FilePath: \Luuu\gis.py
'''
|
import argparse
import os
import sys

import numpy as np
import skimage.io
import torch
from tqdm import tqdm

from backbone import get_backbone
from frame_field_learning import data_transforms, inference, local_utils, save_utils
from frame_field_learning.model import FrameFieldModel
from lydorn_utils import print_utils, run_utils
from torch_lydorn import torchvision
|
class GIS(object):
    """Wraps a trained FrameFieldModel for single-image polygon extraction
    (segmentation + polygonization)."""

    # Default attribute values; any entry can be overridden through **kwargs.
    _defaults = {}

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Parse the CLI arguments, resolve the run's config, then load the model.
        self.args = self.get_args()
        self.config = self.launch_inference_from_filepath(self.args)
        self.generate()
|
    def get_args(self):
        argparser = argparse.ArgumentParser(description=__doc__)
        argparser.add_argument(
            '--in_filepath',
            type=str,
            nargs='*',
            default='images/ex1images',
            help='For launching prediction on several images, use this argument to specify their paths. '
                 'If --out_dirpath is specified, prediction outputs will be saved there. '
                 'If --out_dirpath is not specified, predictions will be saved next to inputs. '
                 'Make sure to also specify the run_name of the model to use for prediction.')
        argparser.add_argument(
            '--out_dirpath',
            type=str,
            default='images',
            help='Path to the output directory of predictions when using the --in_filepath option to launch prediction on several images.')

        argparser.add_argument(
            '-c', '--config',
            type=str,
            help='Name of the config file, excluding the .json file extension.')
        argparser.add_argument(
            '--dataset_params',
            type=str,
            help='Allows overwriting the dataset_params in the config file. Accepts a path to a .json file.')

        argparser.add_argument(
            '-r', '--runs_dirpath',
            default="runs",
            type=str,
            help='Directory where runs are recorded (model saves and logs).')
        argparser.add_argument(
            '--run_name',
            type=str,
            default='mapping_dataset.unet_resnet101_pretrained.train_val',
            help='Name of the run to use. '
                 'That name does not include the timestamp of the folder name: <run_name> | <yyyy-mm-dd hh:mm:ss>.')
        argparser.add_argument(
            '--new_run',
            action='store_true',
            help="Train from scratch (when set) or resume from the last checkpoint (when not set).")
        argparser.add_argument(
            '--init_run_name',
            type=str,
            help="The run_name to initialize the weights from. "
                 "If None, weights will be initialized randomly. "
                 "This is a single word, without the timestamp.")
        argparser.add_argument(
            '--samples',
            type=int,
            help='Limits the number of samples to train (and validate and test) on if set.')

        argparser.add_argument(
            '-b', '--batch_size',
            type=int,
            help='Batch size. The default value can be set in the config file. It is doubled when no back-propagation is done (while in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
        argparser.add_argument(
            '--eval_batch_size',
            type=int,
            help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
        argparser.add_argument(
            '-m', '--mode',
            default="train",
            type=str,
            choices=['train', 'eval', 'eval_coco'],
            help='Mode to launch the script in. '
                 'Train: train the model on the specified folds. '
                 'Eval: evaluate the model on the specified fold. '
                 'Eval_coco: measure COCO metrics on the specified fold.')
        argparser.add_argument(
            '--fold',
            nargs='*',
            type=str,
            choices=['train', 'val', 'test'],
            help='If training (mode=train): all folds entered here will be used for optimizing the network. '
                 'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch. '
                 'The most common scenario is to optimize on train and validate on val: select only train. '
                 'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both the train and val folds. '
                 'Then for evaluation (mode=eval), we might want to evaluate on the val fold for hyper-parameter selection. '
                 'And finally evaluate (mode=eval) on the test fold for the final predictions (and possibly metrics) for the paper/competition.')
        argparser.add_argument(
            '--max_epoch',
            type=int,
            help='Stop training when max_epoch is reached. If not set, the value in the config is used.')
        argparser.add_argument(
            '--eval_patch_size',
            type=int,
            help='When evaluating, the patch size the tile is split into.')
        argparser.add_argument(
            '--eval_patch_overlap',
            type=int,
            help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
                 'the whole tile.')

        argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of the master node")
        argparser.add_argument('--master_port', default="6666", type=str, help="Port on the master node")
        argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Total number of nodes")
        argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of GPUs per node')
        argparser.add_argument('-nr', '--nr', default=0, type=int, help='Rank within the nodes')

        args = argparser.parse_args()
        return args
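
    # Example invocation (a sketch; the image path is a placeholder and the run
    # named by --run_name must exist under --runs_dirpath):
    #   python gis.py --run_name mapping_dataset.unet_resnet101_pretrained.train_val \
    #       --in_filepath images/ex1images/ex1.png --out_dirpath images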
|
|
|
    def launch_inference_from_filepath(self, args):
        # Resolve the run to evaluate: --run_name takes precedence over the
        # "run_name" field of a config file passed with --config.
        run_name = None
        config = None
        if args.run_name is not None:
            run_name = args.run_name

        if args.config is not None:
            config = run_utils.load_config(args.config)
            if config is not None and "run_name" in config and run_name is None:
                run_name = config["run_name"]

        if run_name is None:
            print_utils.print_error("ERROR: the run to evaluate could not be identified with the given arguments. "
                                    "Please specify either the --run_name argument or the --config argument "
                                    "linking to a config file that has a 'run_name' field filled with the name of "
                                    "the run to evaluate.")
            sys.exit(1)

        # Load the run's own config if none was given on the command line.
        run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
        if config is None:
            config = run_utils.load_config(config_dirpath=run_dirpath)
        if config is None:
            print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
                                    f"Exiting now...")
            sys.exit(1)

        # Command-line batch sizes override the config; the evaluation batch size
        # defaults to twice the training batch size (no gradients are stored in eval mode).
        if args.batch_size is not None:
            config["optim_params"]["batch_size"] = args.batch_size
        if args.eval_batch_size is not None:
            config["optim_params"]["eval_batch_size"] = args.eval_batch_size
        else:
            config["optim_params"]["eval_batch_size"] = 2 * config["optim_params"]["batch_size"]

        config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")

        config["eval_params"]["run_dirpath"] = run_dirpath
        if args.eval_patch_size is not None:
            config["eval_params"]["patch_size"] = args.eval_patch_size
        if args.eval_patch_overlap is not None:
            config["eval_params"]["patch_overlap"] = args.eval_patch_overlap

        self.backbone = get_backbone(config["backbone_params"])
        return config
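
    # Shape of the returned config (a sketch listing only the keys this module
    # reads; the values shown are illustrative, not the repo's defaults):
    #   {
    #       "run_name": "...",
    #       "device": "cuda",
    #       "compute_seg": True,
    #       "backbone_params": {...},
    #       "optim_params": {"batch_size": 4, "eval_batch_size": 8,
    #                        "checkpoints_dirname": "checkpoints"},
    #       "eval_params": {"run_dirpath": "...", "patch_size": 1024, "patch_overlap": 200,
    #                       "save_individual_outputs": {"seg_mask": True, "seg": True,
    #                                                   "poly_viz": False, "poly_shapefile": True}},
    #   }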
|
|
|
    def generate(self):
        eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)

        print("Loading model...")
        self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
        self.model.to(self.config["device"])
        # Restore the weights from the run's checkpoints and switch to eval mode.
        checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"],
                                                         self.config["optim_params"]["checkpoints_dirname"])
        self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
        self.model.eval()
|
    def get_save_filepath(self, base_filepath, name=None, ext=""):
        if isinstance(base_filepath, tuple):
            # (dirpath, filename) tuple: each output kind goes to its own subdirectory.
            if name is not None:
                save_filepath = os.path.join(base_filepath[0], name, base_filepath[1] + ext)
            else:
                save_filepath = os.path.join(base_filepath[0], base_filepath[1] + ext)
        elif isinstance(base_filepath, str):
            # Plain path prefix: the output kind is appended to the filename itself.
            if name is not None:
                save_filepath = base_filepath + "." + name + ext
            else:
                save_filepath = base_filepath + ext
        else:
            raise TypeError(f"base_filepath should be a tuple or a str, not {type(base_filepath)}.")
        return save_filepath
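
    # For example (a sketch of both branches):
    #   get_save_filepath(("out", "tile1"), name="seg", ext=".png")  ->  "out/seg/tile1.png"
    #   get_save_filepath("out/tile1", name="seg", ext=".png")       ->  "out/tile1.seg.png"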
|
|
|
    def detect_image(self, in_filepath):
        out_dirpath = self.args.out_dirpath
        image = skimage.io.imread(in_filepath)
        patch_size = self.config['eval_params']['patch_size']

        # If the image is smaller than the evaluation patch size, disable patching
        # and run inference on the whole image at once.
        if patch_size is not None and (image.shape[0] < patch_size or image.shape[1] < patch_size):
            self.config['eval_params']['patch_size'] = None
        if 3 < image.shape[2]:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif image.shape[2] < 3:
            print_utils.print_error(f"Image {in_filepath} has only {image.shape[2]} channels but the network expects 3 channels.")
            raise ValueError(f"Image {in_filepath} has only {image.shape[2]} channels but the network expects 3 channels.")

        # Per-image, per-channel normalization statistics.
        image_float = image / 255
        mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
        std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        # Run the frame-field model with polygonization, then move the
        # single-image batch back to the CPU and unwrap it.
        tile_data = inference.inference(self.config, self.model, sample, compute_polygonization=True)
        tile_data = local_utils.batch_to_cpu(tile_data)
        tile_data = local_utils.split_batch(tile_data)[0]

        # Save outputs next to the input unless --out_dirpath was given.
        if out_dirpath is None:
            out_dirpath = os.path.dirname(in_filepath)
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        out_base_filepath = (out_dirpath, base_filename)

        # Initialize the result paths so the return statement below cannot hit a
        # NameError when an output kind is disabled in the config.
        result_seg_mask_path = None
        result_seg_path = None
        if self.config["compute_seg"]:
            if self.config["eval_params"]["save_individual_outputs"]["seg_mask"]:
                seg_mask = 0.5 < tile_data["seg"][0]
                result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
            if self.config["eval_params"]["save_individual_outputs"]["seg"]:
                result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
        if "poly_viz" in self.config["eval_params"]["save_individual_outputs"] and \
                self.config["eval_params"]["save_individual_outputs"]["poly_viz"]:
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
        if self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"]:
            save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])

        # Paths of the polygon outputs written by the save_utils helpers above.
        pdf_filepath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125', base_filename + ".pdf")
        cpg_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".cpg")
        dbf_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".dbf")
        shx_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shx")
        shp_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shp")
        prj_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".prj")

        return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, cpg_filepath, dbf_filepath, shx_filepath, shp_filepath, prj_filepath]
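
# Minimal usage sketch. GIS() parses sys.argv itself (see get_args), so run this
# module with the CLI arguments defined above; the input image paths and run
# name are whatever those arguments resolve to, nothing here is hard-coded.
if __name__ == "__main__":
    gis = GIS()
    # --in_filepath is a single string when left at its default and a list when
    # given on the command line; normalize to a list before iterating.
    in_filepaths = gis.args.in_filepath if isinstance(gis.args.in_filepath, list) else [gis.args.in_filepath]
    for in_filepath in tqdm(in_filepaths, desc="Inference"):
        base_filename, out_filepaths = gis.detect_image(in_filepath)
        print(base_filename, out_filepaths)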