|
from . import polygonize_utils |
|
from . import polygonize_acm |
|
from . import polygonize_simple |
|
|
|
from lydorn_utils import print_utils |
|
|
|
|
|
class Polygonizer(): |
|
def __init__(self, polygonize_params, pool=None): |
|
self.pool = pool |
|
self.polygonizer_asm = None |
|
|
|
def __call__(self, polygonize_params, seg_batch, crossfield_batch=None, pre_computed=None): |
|
""" |
|
|
|
:param polygonize_params: |
|
:param seg_batch: (N, C, H, W) |
|
:param crossfield_batch: (N, 4, H, W) |
|
:param pre_computed: None o a Dictionary of pre-computed values used for various methods |
|
:return: |
|
""" |
|
assert len(seg_batch.shape) == 4, "seg_batch should be (N, C, H, W)" |
|
assert pre_computed is None or isinstance(pre_computed, dict), "pre_computed should be either None or a dict" |
|
batch_size = seg_batch.shape[0] |
|
|
|
|
|
if type(polygonize_params["method"]) == list: |
|
|
|
if pre_computed is None: |
|
pre_computed = {} |
|
if ("simple" in polygonize_params["method"] or "acm" in polygonize_params["method"]) and "init_contours_batch" not in pre_computed: |
|
indicator_batch = seg_batch[:, 0, :, :] |
|
np_indicator_batch = indicator_batch.cpu().numpy() |
|
init_contours_batch = polygonize_utils.compute_init_contours_batch(np_indicator_batch, |
|
polygonize_params["common_params"][ |
|
"init_data_level"], |
|
pool=self.pool) |
|
pre_computed["init_contours_batch"] = init_contours_batch |
|
|
|
|
|
out_polygons_dict_batch = [{} for _ in range(batch_size)] |
|
out_probs_dict_batch = [{} for _ in range(batch_size)] |
|
for method_name in polygonize_params["method"]: |
|
new_polygonize_params = polygonize_params.copy() |
|
new_polygonize_params["method"] = method_name |
|
polygons_batch, probs_batch = self(new_polygonize_params, seg_batch, |
|
crossfield_batch=crossfield_batch, pre_computed=pre_computed) |
|
if polygons_batch is not None: |
|
for i, (polygons, probs) in enumerate(zip(polygons_batch, probs_batch)): |
|
out_polygons_dict_batch[i][method_name] = polygons |
|
out_probs_dict_batch[i][method_name] = probs |
|
return out_polygons_dict_batch, out_probs_dict_batch |
|
|
|
|
|
if polygonize_params["method"] == "acm": |
|
if crossfield_batch is None: |
|
|
|
return None, None |
|
polygons_batch, probs_batch = polygonize_acm.polygonize(seg_batch, crossfield_batch, |
|
polygonize_params["acm_method"], pool=self.pool, |
|
pre_computed=pre_computed) |
|
elif polygonize_params["method"] == "asm": |
|
from . import polygonize_asm |
|
if crossfield_batch is None: |
|
|
|
return None, None |
|
if self.polygonizer_asm is None: |
|
self.polygonizer_asm = polygonize_asm.PolygonizerASM(polygonize_params["asm_method"], pool=self.pool) |
|
polygons_batch, probs_batch = self.polygonizer_asm(seg_batch, crossfield_batch, pre_computed=pre_computed) |
|
elif polygonize_params["method"] == "simple": |
|
polygons_batch, probs_batch = polygonize_simple.polygonize(seg_batch, polygonize_params["simple_method"], |
|
pool=self.pool, pre_computed=pre_computed) |
|
else: |
|
print_utils.print_error("ERROR: polygonize method {} not recognized!".format(polygonize_params["method"])) |
|
raise NotImplementedError |
|
|
|
return polygons_batch, probs_batch |
|
|
|
|
|
def polygonize(polygonize_params, seg_batch, crossfield_batch=None, pool=None, pre_computed=None): |
|
polygonizer = Polygonizer(polygonize_params, pool=pool) |
|
return polygonizer(polygonize_params, seg_batch, crossfield_batch=crossfield_batch, pre_computed=pre_computed) |
|
|