|
import torch |
|
from torchvision.models.segmentation._utils import _SimpleSegmentationModel |
|
|
|
from frame_field_learning import tta_utils |
|
|
|
|
|
def get_out_channels(module):
    """Return the number of output channels produced by ``module``.

    The module's own ``out_channels`` attribute is returned when it is set.
    Otherwise the children are searched recursively, starting from the last
    child, and the first value found is returned.

    :param module: object exposing an ``out_channels`` attribute and/or a
        ``children()`` method (e.g. a ``torch.nn.Module``)
    :return: the ``out_channels`` value found, or ``None`` when neither the
        module nor any of its descendants provides one
    """
    own_channels = getattr(module, "out_channels", None)
    if own_channels is not None:
        return own_channels

    # An attribute that is missing or left at None says nothing about the
    # output width, so look at the children, last one first.
    for child in reversed(list(module.children())):
        found = get_out_channels(child)
        if found is not None:
            return found
    return None
|
|
|
|
|
class FrameFieldModel(torch.nn.Module):
    """Segmentation and cross-field heads on top of a shared backbone.

    The backbone is called on the input image and the ``"out"`` entry of its
    result is used as the feature map. Depending on ``config`` a segmentation
    head and/or a cross-field head is applied to those features. When both
    are enabled, the cross-field head also receives the (detached)
    segmentation output as extra input channels.
    """

    def __init__(self, config: dict, backbone, train_transform=None, eval_transform=None):
        """
        Build the heads requested by ``config``.

        :param config: model configuration. Keys read here: "compute_seg",
            "compute_crossfield" and, when "compute_seg" is truthy,
            "seg_params" with "compute_vertex", "compute_edge" and
            "compute_interior". ``forward`` additionally reads
            config["eval_params"]["seg_threshold"] when called with tta=True.
        :param backbone: A _SimpleSegmentationModel network, its output features will be used to compute seg and framefield.
        :param train_transform: transform applied to the inputs when self.training is True
        :param eval_transform: transform applied to the inputs when self.training is False
        :raises ValueError: if the number of output channels of the backbone
            cannot be determined
        """
        super().__init__()
        assert config["compute_seg"] or config["compute_crossfield"], \
            "Model has to compute at least one of those:\n" \
            "\t- segmentation\n" \
            "\t- cross-field"
        assert isinstance(backbone, _SimpleSegmentationModel), \
            "backbone should be an instance of _SimpleSegmentationModel"
        self.config = config
        self.backbone = backbone
        self.train_transform = train_transform
        self.eval_transform = eval_transform

        backbone_out_features = get_out_channels(self.backbone)
        if backbone_out_features is None:
            # Fail here with a clear message rather than inside Conv2d.
            raise ValueError(
                "Could not determine the number of output channels of the backbone")

        # Stays 0 when no segmentation is computed, so the cross-field head
        # then only takes the backbone features as input.
        seg_channels = 0
        if self.config["compute_seg"]:
            seg_params = self.config["seg_params"]
            # One output channel per enabled segmentation class.
            seg_channels = seg_params["compute_vertex"] \
                + seg_params["compute_edge"] \
                + seg_params["compute_interior"]
            self.seg_module = torch.nn.Sequential(
                torch.nn.Conv2d(backbone_out_features, backbone_out_features, 3, padding=1),
                torch.nn.BatchNorm2d(backbone_out_features),
                torch.nn.ELU(),
                torch.nn.Conv2d(backbone_out_features, seg_channels, 1),
                torch.nn.Sigmoid(),
            )

        if self.config["compute_crossfield"]:
            crossfield_channels = 4
            self.crossfield_module = torch.nn.Sequential(
                torch.nn.Conv2d(backbone_out_features + seg_channels, backbone_out_features, 3, padding=1),
                torch.nn.BatchNorm2d(backbone_out_features),
                torch.nn.ELU(),
                torch.nn.Conv2d(backbone_out_features, crossfield_channels, 1),
                torch.nn.Tanh(),
            )

    def inference(self, image):
        """
        Run the backbone and the enabled heads on ``image``.

        :param image: input passed as-is to the backbone
        :return: dict with a "seg" entry when config["compute_seg"] is truthy
            and a "crossfield" entry when config["compute_crossfield"] is
            truthy. The cross-field values are the Tanh output scaled by 2.
        """
        outputs = {}

        backbone_features = self.backbone(image)["out"]

        if self.config["compute_seg"]:
            seg = self.seg_module(backbone_features)
            # Detach so the cross-field branch does not back-propagate into
            # the segmentation head. No clone() is needed: torch.cat already
            # copies its inputs into a new tensor.
            seg_to_cat = seg.detach()
            backbone_features = torch.cat([backbone_features, seg_to_cat], dim=1)
            outputs["seg"] = seg

        if self.config["compute_crossfield"]:
            # Tanh outputs lie in [-1, 1]; the factor 2 widens that to [-2, 2].
            crossfield = 2 * self.crossfield_module(backbone_features)
            outputs["crossfield"] = crossfield

        return outputs

    def forward(self, xb, tta=False):
        """
        Apply the mode-specific transform to the batch, then run inference.

        :param xb: input batch. After the optional transform it must provide
            an "image" entry.
        :param tta: when True, inference is delegated to
            tta_utils.tta_inference, which is given this model, the batch and
            config["eval_params"]["seg_threshold"]
        :return: tuple (outputs, xb) where xb is the (possibly transformed) batch
        """
        transform = self.train_transform if self.training else self.eval_transform
        if transform is not None:
            xb = transform(xb)

        if not tta:
            final_outputs = self.inference(xb["image"])
        else:
            final_outputs = tta_utils.tta_inference(self, xb, self.config["eval_params"]["seg_threshold"])

        return final_outputs, xb
|
|