File size: 4,632 Bytes
abd2a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from . import polygonize_utils
from . import polygonize_acm
from . import polygonize_simple

from lydorn_utils import print_utils


class Polygonizer():
    """Turns batches of segmentation (and optionally crossfield) maps into polygons.

    The method(s) to run are chosen per call by polygonize_params["method"]: either a single method name
    ("acm", "asm" or "simple") or a list/tuple of method names that are run one after the other.
    """

    def __init__(self, polygonize_params, pool=None):
        """
        :param polygonize_params: Not used by the constructor; the params passed to __call__ are the ones applied.
            Kept for backward compatibility of the constructor signature.
        :param pool: Optional pool forwarded to the polygonization methods and to the init contours computation.
        """
        self.pool = pool
        # The ASM polygonizer is built lazily on the first "asm" call and cached, together with the
        # polygonize_params["asm_method"] object it was built from (so it can be rebuilt if that object changes).
        self.polygonizer_asm = None
        self._polygonizer_asm_params = None

    def __call__(self, polygonize_params, seg_batch, crossfield_batch=None, pre_computed=None):
        """Polygonize a batch with the method(s) given by polygonize_params["method"].

        :param polygonize_params: Dict with at least a "method" key (a method name, or a list/tuple of names).
            Depending on the method(s), "common_params" (its "init_data_level" entry), "acm_method", "asm_method"
            and/or "simple_method" are read as well.
        :param seg_batch: (N, C, H, W) segmentation batch. Channel 0 is used as the indicator when computing the
            init contours shared by several methods.
        :param crossfield_batch: (N, 4, H, W) crossfield batch, or None. Without it, "acm" and "asm" cannot run.
        :param pre_computed: None or a dictionary of pre-computed values used by the various methods. When a
            list/tuple of methods is run, the caller's dict is not modified: shared values go into a shallow copy.
        :return: For a single method: (polygons_batch, probs_batch) as returned by that method, or (None, None)
            when the method needs a crossfield and none was given.
            For a list/tuple of methods: (out_polygons_dict_batch, out_probs_dict_batch), two lists of N dicts
            mapping each method name that could run to its per-image result.
        :raises NotImplementedError: If a method name is not recognized.
        """
        assert len(seg_batch.shape) == 4, "seg_batch should be (N, C, H, W)"
        assert pre_computed is None or isinstance(pre_computed, dict), "pre_computed should be either None or a dict"
        if isinstance(polygonize_params["method"], (list, tuple)):
            return self._call_multi(polygonize_params, seg_batch, crossfield_batch, pre_computed)
        return self._call_single(polygonize_params, seg_batch, crossfield_batch, pre_computed)

    def _call_multi(self, polygonize_params, seg_batch, crossfield_batch, pre_computed):
        """Run every method of polygonize_params["method"] in order and gather the results per image and method."""
        method_names = polygonize_params["method"]
        # Work on a shallow copy so values shared between methods do not leak into (and go stale in) the caller's
        # dict.
        pre_computed = {} if pre_computed is None else dict(pre_computed)

        # --- For speed up, pre-compute anything that is used by multiple methods:
        if ("simple" in method_names or "acm" in method_names) and "init_contours_batch" not in pre_computed:
            np_indicator_batch = seg_batch[:, 0, :, :].cpu().numpy()
            pre_computed["init_contours_batch"] = polygonize_utils.compute_init_contours_batch(
                np_indicator_batch, polygonize_params["common_params"]["init_data_level"], pool=self.pool)
        # ---

        batch_size = seg_batch.shape[0]
        out_polygons_dict_batch = [{} for _ in range(batch_size)]
        out_probs_dict_batch = [{} for _ in range(batch_size)]
        for method_name in method_names:
            new_polygonize_params = polygonize_params.copy()
            new_polygonize_params["method"] = method_name
            polygons_batch, probs_batch = self(new_polygonize_params, seg_batch,
                                               crossfield_batch=crossfield_batch, pre_computed=pre_computed)
            if polygons_batch is None:
                # The method could not run (e.g. "acm"/"asm" without a crossfield): leave it out of the results.
                continue
            for i, (polygons, probs) in enumerate(zip(polygons_batch, probs_batch)):
                out_polygons_dict_batch[i][method_name] = polygons
                out_probs_dict_batch[i][method_name] = probs
        return out_polygons_dict_batch, out_probs_dict_batch

    def _call_single(self, polygonize_params, seg_batch, crossfield_batch, pre_computed):
        """Run the single method named by polygonize_params["method"]."""
        method = polygonize_params["method"]
        if method == "acm":
            if crossfield_batch is None:
                # Cannot run the ACM method
                return None, None
            polygons_batch, probs_batch = polygonize_acm.polygonize(seg_batch, crossfield_batch,
                                                                    polygonize_params["acm_method"], pool=self.pool,
                                                                    pre_computed=pre_computed)
        elif method == "asm":
            # Imported lazily, only when the ASM method is actually requested.
            from . import polygonize_asm
            if crossfield_batch is None:
                # Cannot run the ASM method
                return None, None
            asm_params = polygonize_params["asm_method"]
            # Rebuild the cached ASM polygonizer when a different "asm_method" params object is passed; otherwise a
            # reused Polygonizer would silently keep the params of its first "asm" call.
            if self.polygonizer_asm is None or self._polygonizer_asm_params is not asm_params:
                self.polygonizer_asm = polygonize_asm.PolygonizerASM(asm_params, pool=self.pool)
                self._polygonizer_asm_params = asm_params
            polygons_batch, probs_batch = self.polygonizer_asm(seg_batch, crossfield_batch, pre_computed=pre_computed)
        elif method == "simple":
            polygons_batch, probs_batch = polygonize_simple.polygonize(seg_batch, polygonize_params["simple_method"],
                                                                       pool=self.pool, pre_computed=pre_computed)
        else:
            print_utils.print_error("ERROR: polygonize method {} not recognized!".format(method))
            raise NotImplementedError("polygonize method {} not recognized".format(method))

        return polygons_batch, probs_batch


def polygonize(polygonize_params, seg_batch, crossfield_batch=None, pool=None, pre_computed=None):
    """Convenience wrapper: polygonize one batch with a freshly created Polygonizer.

    A new Polygonizer is built on every call, so nothing it caches internally survives between calls.

    :param polygonize_params: Polygonization params, passed both to the Polygonizer constructor and to its call.
    :param seg_batch: (N, C, H, W) segmentation batch.
    :param crossfield_batch: Optional crossfield batch.
    :param pool: Optional pool handed to the Polygonizer.
    :param pre_computed: None or a dict of pre-computed values.
    :return: Whatever Polygonizer.__call__ returns for these arguments.
    """
    return Polygonizer(polygonize_params, pool=pool)(
        polygonize_params,
        seg_batch,
        crossfield_batch=crossfield_batch,
        pre_computed=pre_computed,
    )