File size: 5,852 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import numpy as np
import skimage.io
import torch
from tqdm import tqdm
from . import data_transforms, save_utils
from .model import FrameFieldModel
from . import inference
from . import local_utils
from torch_lydorn import torchvision
from lydorn_utils import print_utils
from lydorn_utils import run_utils
# Lower-cased file extensions accepted as input images.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp"})


def _list_image_filepaths(dirpath):
    """Return the sorted paths of the image files found directly inside *dirpath*."""
    # os.listdir returns entries in arbitrary order; sorting makes runs reproducible.
    return [
        os.path.join(dirpath, filename)
        for filename in sorted(os.listdir(dirpath))
        if os.path.splitext(filename)[1].lower() in _IMAGE_EXTENSIONS
    ]


def _load_rgb_image(filepath):
    """Read *filepath* and return an array with exactly 3 channels on its last axis.

    Extra channels are dropped (with an info message). Images that are not 3-D
    or that have fewer than 3 channels raise ValueError.
    """
    image = skimage.io.imread(filepath)

    # A 2-D (single-channel) image has no channel axis: fail with a clear message
    # instead of an IndexError on image.shape[2].
    if image.ndim != 3:
        message = f"Image {filepath} has shape {image.shape} but the network expects a 3-channel image."
        print_utils.print_error(message)
        raise ValueError(message)

    channel_count = image.shape[2]
    if channel_count < 3:
        message = f"Image {filepath} has only {channel_count} channels but the network expects 3 channels."
        print_utils.print_error(message)
        raise ValueError(message)
    if 3 < channel_count:
        print_utils.print_info(f"Image {filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
        image = image[:, :, :3]
    return image


def _save_outputs(config, tile_data, out_base_filepath):
    """Write the outputs enabled in config["eval_params"]["save_individual_outputs"]."""
    save_flags = config["eval_params"]["save_individual_outputs"]
    image_filepath = tile_data["image_filepath"]

    if config["compute_seg"]:
        if save_flags["seg_mask"]:
            # Binarize the first segmentation channel at 0.5.
            seg_mask = 0.5 < tile_data["seg"][0]
            save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", image_filepath)
        if save_flags["seg"]:
            save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", image_filepath)
        if save_flags["seg_luxcarta"]:
            save_utils.save_seg_luxcarta_format(tile_data["seg"], out_base_filepath, "seg_luxcarta_format", image_filepath)

    if config["compute_crossfield"] and save_flags["crossfield"]:
        save_utils.save_crossfield(tile_data["crossfield"], out_base_filepath, "crossfield")

    # "poly_viz" is an optional key: a missing key means "do not save".
    if save_flags.get("poly_viz"):
        save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
    if save_flags["poly_shapefile"]:
        save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", image_filepath)


def inference_from_filepath(config, in_filepaths, backbone, out_dirpath=None):
    """Run inference on every image of a directory and save the requested outputs.

    Args:
        config: Configuration dictionary. This function reads "device",
            "compute_seg", "compute_crossfield", "eval_params" ("run_dirpath",
            "save_individual_outputs") and "optim_params" ("checkpoints_dirname").
        in_filepaths: Path of a directory (despite the plural name). Every file
            directly inside it whose extension is .jpg, .jpeg, .png or .bmp
            (case-insensitive) is processed, in sorted order.
        backbone: Backbone forwarded as-is to FrameFieldModel.
        out_dirpath: Directory handed to the save functions. When None, the
            directory containing each input image is used.

    Raises:
        ValueError: If an image is not a 3-D array or has fewer than 3 channels.
    """
    # --- Online transform performed on the device (GPU):
    eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(config)

    print("Loading model...")
    model = FrameFieldModel(config, backbone=backbone, eval_transform=eval_online_cuda_transform)
    model.to(config["device"])
    checkpoints_dirpath = run_utils.setup_run_subdir(config["eval_params"]["run_dirpath"], config["optim_params"]["checkpoints_dirname"])
    model = inference.load_checkpoint(model, checkpoints_dirpath, config["device"])
    model.eval()

    in_filepath_list = _list_image_filepaths(in_filepaths)
    if not in_filepath_list:
        print_utils.print_info(f"No image found in {in_filepaths}: nothing to infer.")

    pbar = tqdm(in_filepath_list, desc="Infer images")
    for in_filepath in pbar:
        print(in_filepath)

        pbar.set_postfix(status="Loading image")
        image = _load_rgb_image(in_filepath)

        # Per-channel statistics of the image scaled by 1/255
        # (presumably 8-bit pixel values -- TODO confirm for other bit depths).
        image_float = image / 255
        pixels = image_float.reshape(-1, image_float.shape[-1])
        mean = np.mean(pixels, axis=0)
        std = np.std(pixels, axis=0)

        # Each entry gets a leading batch dimension of size 1.
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        pbar.set_postfix(status="Inference")
        tile_data = inference.inference(config, model, sample, compute_polygonization=True)
        tile_data = local_utils.batch_to_cpu(tile_data)
        # Remove batch dim:
        tile_data = local_utils.split_batch(tile_data)[0]

        # --- Saving outputs --- #
        pbar.set_postfix(status="Saving output")
        # Resolve the output directory per image without overwriting the
        # out_dirpath argument.
        image_out_dirpath = os.path.dirname(in_filepath) if out_dirpath is None else out_dirpath
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        # The save functions receive a (directory, base filename) pair.
        out_base_filepath = (image_out_dirpath, base_filename)
        _save_outputs(config, tile_data, out_base_filepath)
# if config["eval_params"]["save_individual_outputs"]["seg_gt"]:
# save_utils.save_seg(tile_data["gt_polygons_image"], base_filepath, "seg.gt", tile_data["image_filepath"])
# if config["eval_params"]["save_individual_outputs"]["seg"]:
# save_utils.save_seg(tile_data["seg"], base_filepath, "seg", tile_data["image_filepath"])
# if config["eval_params"]["save_individual_outputs"]["seg_mask"]:
# save_utils.save_seg_mask(tile_data["seg_mask"], base_filepath, "seg_mask", tile_data["image_filepath"])
# if config["eval_params"]["save_individual_outputs"]["seg_opencities_mask"]:
# save_utils.save_opencities_mask(tile_data["seg_mask"], base_filepath, "drivendata",
# tile_data["image_filepath"])
# if config["eval_params"]["save_individual_outputs"]["seg_luxcarta"]:
# save_utils.save_seg_luxcarta_format(tile_data["seg"], base_filepath, "seg_luxcarta_format",
# tile_data["image_filepath"])
# if config["eval_params"]["save_individual_outputs"]["crossfield"]:
# save_utils.save_crossfield(tile_data["crossfield"], base_filepath, "crossfield")
# if config["eval_params"]["save_individual_outputs"]["uv_angles"]:
# save_utils.save_uv_angles(tile_data["crossfield"], base_filepath, "uv_angles", tile_data["image_filepath"])
#
# if "polygons" in tile_data:
# save_utils.save_polygons(tile_data["polygons"], base_filepath, "polygons", tile_data["image_filepath"])
|