File size: 5,852 Bytes
abd2a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import numpy as np
import skimage.io
import torch

from tqdm import tqdm

from . import data_transforms, save_utils
from .model import FrameFieldModel
from . import inference
from . import local_utils

from torch_lydorn import torchvision

from lydorn_utils import print_utils
from lydorn_utils import run_utils


# Lower-case extensions (with their leading dot) of the directory entries treated as input images.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def _list_image_filepaths(dirpath):
    """Return the sorted paths of the image entries directly inside dirpath.

    An entry is kept when its extension, compared case-insensitively, is one of
    _IMAGE_EXTENSIONS. Sub-directories are not searched. Sorting makes the
    processing order deterministic (os.listdir returns entries in arbitrary order).
    """
    return [
        os.path.join(dirpath, filename)
        for filename in sorted(os.listdir(dirpath))
        if os.path.splitext(filename)[1].lower() in _IMAGE_EXTENSIONS
    ]


def _channel_count(image):
    """Return the number of channels of an image array.

    Arrays with fewer than 3 dimensions (e.g. a grayscale image loaded as a 2-D array)
    are reported as single-channel instead of raising IndexError on image.shape[2].
    """
    return image.shape[2] if image.ndim >= 3 else 1


def inference_from_filepath(config, in_filepaths, backbone, out_dirpath=None):
    """Run the frame field model on every image of a directory and save the requested outputs.

    Args:
        config: Nested configuration mapping. Keys read here: "device", "compute_seg",
            "compute_crossfield", "eval_params" ("run_dirpath", "save_individual_outputs")
            and "optim_params" ("checkpoints_dirname").
        in_filepaths: Despite its name, the path of a *directory*. Every entry in it whose
            extension is .jpg, .jpeg, .png or .bmp (any letter case) is processed.
        backbone: Forwarded as-is to FrameFieldModel.
        out_dirpath: Directory part of the (dirpath, base filename) tuple given to the
            save_utils functions. When None, the directory of each input image is used.

    Raises:
        ValueError: If an image has fewer than 3 channels (including 2-D grayscale images).
    """
    # --- Online transform performed on the device (GPU):
    eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(config)

    print("Loading model...")
    model = FrameFieldModel(config, backbone=backbone, eval_transform=eval_online_cuda_transform)
    model.to(config["device"])
    checkpoints_dirpath = run_utils.setup_run_subdir(config["eval_params"]["run_dirpath"],
                                                     config["optim_params"]["checkpoints_dirname"])
    model = inference.load_checkpoint(model, checkpoints_dirpath, config["device"])
    model.eval()

    pbar = tqdm(_list_image_filepaths(in_filepaths), desc="Infer images")
    for in_filepath in pbar:
        print(in_filepath)
        pbar.set_postfix(status="Loading image")
        image = skimage.io.imread(in_filepath)

        # Validate the channel count without assuming the array is 3-D.
        nb_channels = _channel_count(image)
        if 3 < nb_channels:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif nb_channels < 3:
            message = f"Image {in_filepath} has only {nb_channels} channels but the network expects 3 channels."
            print_utils.print_error(message)
            raise ValueError(message)

        # Per-channel statistics are computed on the image rescaled by 1/255.
        image_float = image / 255
        flat_pixels = image_float.reshape(-1, image_float.shape[-1])
        mean = np.mean(flat_pixels, axis=0)
        std = np.std(flat_pixels, axis=0)
        # Each entry gets a leading batch dimension of size 1.
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        pbar.set_postfix(status="Inference")
        tile_data = inference.inference(config, model, sample, compute_polygonization=True)

        tile_data = local_utils.batch_to_cpu(tile_data)

        # Remove batch dim:
        tile_data = local_utils.split_batch(tile_data)[0]

        # --- Saving outputs --- #

        pbar.set_postfix(status="Saving output")

        # Figuring out_base_filepath out. A local is used so that the out_dirpath
        # argument is never overwritten by the directory of the first image.
        image_out_dirpath = os.path.dirname(in_filepath) if out_dirpath is None else out_dirpath
        base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
        out_base_filepath = (image_out_dirpath, base_filename)

        save_flags = config["eval_params"]["save_individual_outputs"]

        if config["compute_seg"]:
            if save_flags["seg_mask"]:
                # Binary mask: first seg channel thresholded at 0.5.
                seg_mask = 0.5 < tile_data["seg"][0]
                save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
            if save_flags["seg"]:
                save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
            if save_flags["seg_luxcarta"]:
                save_utils.save_seg_luxcarta_format(tile_data["seg"], out_base_filepath, "seg_luxcarta_format", tile_data["image_filepath"])

        if config["compute_crossfield"] and save_flags["crossfield"]:
            save_utils.save_crossfield(tile_data["crossfield"], out_base_filepath, "crossfield")

        # "poly_viz" is optional in the config, unlike the other flags.
        if save_flags.get("poly_viz"):
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
        if save_flags["poly_shapefile"]:
            save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["seg_gt"]:
        #     save_utils.save_seg(tile_data["gt_polygons_image"], base_filepath, "seg.gt", tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["seg"]:
        #     save_utils.save_seg(tile_data["seg"], base_filepath, "seg", tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["seg_mask"]:
        #     save_utils.save_seg_mask(tile_data["seg_mask"], base_filepath, "seg_mask", tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["seg_opencities_mask"]:
        #     save_utils.save_opencities_mask(tile_data["seg_mask"], base_filepath, "drivendata",
        #                                    tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["seg_luxcarta"]:
        #     save_utils.save_seg_luxcarta_format(tile_data["seg"], base_filepath, "seg_luxcarta_format",
        #                                        tile_data["image_filepath"])
        # if config["eval_params"]["save_individual_outputs"]["crossfield"]:
        #     save_utils.save_crossfield(tile_data["crossfield"], base_filepath, "crossfield")
        # if config["eval_params"]["save_individual_outputs"]["uv_angles"]:
        #     save_utils.save_uv_angles(tile_data["crossfield"], base_filepath, "uv_angles", tile_data["image_filepath"])
        #
        # if "polygons" in tile_data:
        #     save_utils.save_polygons(tile_data["polygons"], base_filepath, "polygons", tile_data["image_filepath"])