''' |
Author: Egrt |
Date: 2022-03-19 10:25:50 |
LastEditors: Egrt |
LastEditTime: 2022-03-20 14:58:13 |
FilePath: \Luuu\gis.py |
''' |
import os |
import numpy as np |
import skimage.io |
import torch |
from tqdm import tqdm |
from frame_field_learning import data_transforms, save_utils |
from frame_field_learning.model import FrameFieldModel |
from frame_field_learning import inference |
from frame_field_learning import local_utils |
from backbone import get_backbone |
from torch_lydorn import torchvision |
import argparse |
from lydorn_utils import print_utils |
from lydorn_utils import run_utils |
class GIS(object):
    """Wrap a trained FrameFieldModel so single image files can be polygonized.

    On construction this parses command-line arguments, loads the run's config,
    builds the model, loads its checkpoint, and switches it to eval mode.
    ``detect_image`` then runs inference (with polygonization) on one image and
    saves the configured outputs.
    """

    _defaults = {
    }

    def __init__(self, **kwargs):
        """Set default/keyword attributes, then parse CLI args, load config and model.

        Keyword arguments are stored verbatim as instance attributes.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.args = self.get_args()
        self.config = self.launch_inference_from_filepath(self.args)
        self.generate()

    def get_args(self):
        """Parse the process command line (``sys.argv``) and return the namespace.

        Only some options are read elsewhere in this class: ``out_dirpath``,
        ``config``, ``runs_dirpath``, ``run_name``, ``batch_size``,
        ``eval_batch_size``, ``eval_patch_size`` and ``eval_patch_overlap``.
        The remaining training/distributed options are accepted but not read
        anywhere in this class.
        """
        argparser = argparse.ArgumentParser(description=__doc__)
        argparser.add_argument(
            '--in_filepath',
            type=str,
            nargs='*',
            default='images/ex1images',
            help='For launching prediction on several images, use this argument to specify their paths.'
                 'If --out_dirpath is specified, prediction outputs will be saved there..'
                 'If --out_dirpath is not specified, predictions will be saved next to inputs.'
                 'Make sure to also specify the run_name of the model to use for prediction.')
        argparser.add_argument(
            '--out_dirpath',
            type=str,
            default='images',
            help='Path to the output directory of prediction when using the --in_filepath option to launch prediction on several images.')
        argparser.add_argument(
            '-c', '--config',
            type=str,
            help='Name of the config file, excluding the .json file extension.')
        argparser.add_argument(
            '--dataset_params',
            type=str,
            help='Allows to overwrite the dataset_params in the config file. Accepts a path to a .json file.')
        argparser.add_argument(
            '-r', '--runs_dirpath',
            default="runs",
            type=str,
            help='Directory where runs are recorded (model saves and logs).')
        argparser.add_argument(
            '--run_name',
            type=str,
            default='mapping_dataset.unet_resnet101_pretrained.train_val',
            help='Name of the run to use.'
                 'That name does not include the timestamp of the folder name: <run_name> | <yyyy-mm-dd hh:mm:ss>.')
        argparser.add_argument(
            '--new_run',
            action='store_true',
            help="Train from scratch (when True) or train from the last checkpoint (when False)")
        argparser.add_argument(
            '--init_run_name',
            type=str,
            help="This is the run_name to initialize the weights from."
                 "If None, weights will be initialized randomly."
                 "This is a single word, without the timestamp.")
        argparser.add_argument(
            '--samples',
            type=int,
            help='Limits the number of samples to train (and validate and test) if set.')
        argparser.add_argument(
            '-b', '--batch_size',
            type=int,
            help='Batch size. Default value can be set in config file. Is doubled when no back propagation is done (while in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
        argparser.add_argument(
            '--eval_batch_size',
            type=int,
            help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
        argparser.add_argument(
            '-m', '--mode',
            default="train",
            type=str,
            choices=['train', 'eval', 'eval_coco'],
            help='Mode to launch the script in. '
                 'Train: train model on speciffied folds. '
                 'Eval: eval model on specified fold. '
                 'Eval_coco: measures COCO metrics of specified fold')
        argparser.add_argument(
            '--fold',
            nargs='*',
            type=str,
            choices=['train', 'val', 'test'],
            help='If training (mode=train): all folds entered here will be used for optimizing the network.'
                 'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch.'
                 'The most common scenario is to optimize on train and validate on val: select only train.'
                 'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both train and val folds.'
                 'Then for evaluation (mode=eval), we might want to evaluate on the val folds for hyper-parameter selection.'
                 'And finally evaluate (mode=eval) on the test fold for the final predictions (and possibly metric) for the paper/competition')
        argparser.add_argument(
            '--max_epoch',
            type=int,
            help='Stop training when max_epoch is reached. If not set, value in config is used.')
        argparser.add_argument(
            '--eval_patch_size',
            type=int,
            help='When evaluating, patch size the tile split into.')
        argparser.add_argument(
            '--eval_patch_overlap',
            type=int,
            help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
                 'the whole tile')
        argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of master node")
        argparser.add_argument('--master_port', default="6666", type=str, help="Port on master node")
        argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Number of total nodes")
        argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of gpus per node')
        argparser.add_argument('-nr', '--nr', default=0, type=int, help='Ranking within the nodes')
        args = argparser.parse_args()
        return args

    def launch_inference_from_filepath(self, args):
        """Resolve the run, load and complete its config, and build the backbone.

        The run name comes from ``--run_name``. If that is unset, it falls back to
        the ``run_name`` field of the ``--config`` file. If no config file is given,
        the config stored in the run directory is loaded. CLI batch-size and
        eval-patch overrides are then applied.

        Side effect: sets ``self.backbone`` from ``config["backbone_params"]``.

        Returns:
            The completed config dict.

        Raises:
            SystemExit: If the run cannot be identified or its config cannot be
                loaded. Before this fix, execution continued and crashed later
                with an unrelated error.
        """
        run_name = args.run_name
        config = run_utils.load_config(args.config) if args.config is not None else None
        if run_name is None and config is not None and "run_name" in config:
            run_name = config["run_name"]
        if run_name is None:
            print_utils.print_error("ERROR: the run to evaluate could no be identified with the given arguments. "
                                    "Please specify either the --run_name argument or the --config argument "
                                    "linking to a config file that has a 'run_name' field filled with the name of "
                                    "the run name to evaluate.")
            raise SystemExit(1)

        run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
        if config is None:
            config = run_utils.load_config(config_dirpath=run_dirpath)
        if config is None:
            print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
                                    f"Exiting now...")
            raise SystemExit(1)

        optim_params = config["optim_params"]
        if args.batch_size is not None:
            optim_params["batch_size"] = args.batch_size
        # Without an explicit eval batch size, evaluation uses twice the training batch
        # size (matching the --batch_size help text).
        if args.eval_batch_size is not None:
            optim_params["eval_batch_size"] = args.eval_batch_size
        else:
            optim_params["eval_batch_size"] = 2 * optim_params["batch_size"]

        config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")
        eval_params = config["eval_params"]
        eval_params["run_dirpath"] = run_dirpath
        if args.eval_patch_size is not None:
            eval_params["patch_size"] = args.eval_patch_size
        if args.eval_patch_overlap is not None:
            eval_params["patch_overlap"] = args.eval_patch_overlap

        self.backbone = get_backbone(config["backbone_params"])
        return config

    def generate(self):
        """Build the FrameFieldModel, load its checkpoint, and put it in eval mode."""
        eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)
        print("Loading model...")
        self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
        self.model.to(self.config["device"])
        checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"],
                                                         self.config["optim_params"]["checkpoints_dirname"])
        self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
        self.model.eval()

    def get_save_filepath(self, base_filepath, name=None, ext=""):
        """Build an output path from a base path, an optional output name, and an extension.

        Args:
            base_filepath: Either a ``(dirpath, filename)`` tuple, giving
                ``dirpath/[name/]filename + ext``, or a str, giving
                ``base_filepath[.name] + ext``.
            name: Optional output name (a sub-directory for tuples, a dotted
                suffix for strings).
            ext: Extension appended verbatim (include the leading dot).

        Raises:
            TypeError: If ``base_filepath`` is neither a tuple nor a str.
                Previously this surfaced as an UnboundLocalError.
        """
        if isinstance(base_filepath, tuple):
            if name is not None:
                return os.path.join(base_filepath[0], name, base_filepath[1] + ext)
            return os.path.join(base_filepath[0], base_filepath[1] + ext)
        if isinstance(base_filepath, str):
            if name is not None:
                return base_filepath + "." + name + ext
            return base_filepath + ext
        raise TypeError(f"base_filepath must be a (dirpath, filename) tuple or a str, "
                        f"got {type(base_filepath).__name__}")

    def detect_image(self, in_filepath):
        """Run inference with polygonization on one image file and save the outputs.

        Args:
            in_filepath: Path of the image to process (read with ``skimage.io.imread``).

        Returns:
            ``(base_filename, paths)``. ``base_filename`` is the input file name
            without its extension. ``paths`` is
            ``[seg_mask_path, seg_path, pdf_path, cpg_path, dbf_path, shx_path,
            shp_path, prj_path]``. The first two entries are None when that
            output was not saved. The rest are built from a fixed naming scheme
            and are not checked for existence.

        Raises:
            ValueError: If the image has fewer than 3 channels. A 2-D array
                counts as one channel.
        """
        out_dirpath = self.args.out_dirpath
        image = skimage.io.imread(in_filepath)

        # A 2-D array has no channel axis; count it as 1 channel so it gets the same
        # error as other <3-channel images instead of an IndexError.
        n_channels = image.shape[2] if image.ndim == 3 else 1
        if 3 < n_channels:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif n_channels < 3:
            print_utils.print_error(f"Image {in_filepath} has only {n_channels} channels but the network expects 3 channels.")
            raise ValueError(f"Image {in_filepath} has {n_channels} channels, expected 3.")

        # NOTE(review): dividing by 255 assumes 8-bit pixel values; confirm for other dtypes.
        image_float = image / 255
        pixels = image_float.reshape(-1, image_float.shape[-1])
        mean = np.mean(pixels, axis=0)
        std = np.std(pixels, axis=0)
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        # Images smaller than one patch are processed whole (patch_size=None). The
        # override only lasts for this call. Writing it permanently would make the
        # `<` comparison raise TypeError on the next call and would disable
        # patching for later large images.
        eval_params = self.config["eval_params"]
        configured_patch_size = eval_params["patch_size"]
        if configured_patch_size is not None and (
                image.shape[0] < configured_patch_size or image.shape[1] < configured_patch_size):
            eval_params["patch_size"] = None
        try:
            tile_data = inference.inference(self.config, self.model, sample, compute_polygonization=True)
        finally:
            eval_params["patch_size"] = configured_patch_size

        tile_data = local_utils.batch_to_cpu(tile_data)
        tile_data = local_utils.split_batch(tile_data)[0]

        if out_dirpath is None:
            out_dirpath = os.path.dirname(in_filepath)
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        out_base_filepath = (out_dirpath, base_filename)

        save_outputs = eval_params["save_individual_outputs"]
        # Pre-initialised so the return works when segmentation outputs are disabled.
        result_seg_mask_path = None
        result_seg_path = None
        if self.config["compute_seg"]:
            if save_outputs["seg_mask"]:
                seg_mask = 0.5 < tile_data["seg"][0]
                result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
            if save_outputs["seg"]:
                result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
        if "poly_viz" in save_outputs and save_outputs["poly_viz"]:
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
        if save_outputs["poly_shapefile"]:
            save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])

        # NOTE(review): these sub-directory names are hard-coded to the "acm" method
        # with tolerance 0.125. They presumably must match what save_utils writes for
        # the active polygonization config; confirm if that config changes.
        poly_viz_dirpath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125')
        shapefile_dirpath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125')
        pdf_filepath = os.path.join(poly_viz_dirpath, base_filename + ".pdf")
        shapefile_paths = [os.path.join(shapefile_dirpath, base_filename + ext)
                           for ext in (".cpg", ".dbf", ".shx", ".shp", ".prj")]
        return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, *shapefile_paths]