|
import os |
|
|
|
import torch |
|
import torchvision |
|
|
|
from lydorn_utils import print_utils |
|
|
|
|
|
def get_backbone(backbone_params): |
|
set_download_dir() |
|
if backbone_params["name"] == "unet": |
|
from torchvision.models.segmentation._utils import _SimpleSegmentationModel |
|
from frame_field_learning.unet import UNetBackbone |
|
|
|
backbone = UNetBackbone(backbone_params["input_features"], backbone_params["features"]) |
|
backbone = _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity()) |
|
elif backbone_params["name"] == "fcn50": |
|
backbone = torchvision.models.segmentation.fcn_resnet50(pretrained=backbone_params["pretrained"], |
|
num_classes=21) |
|
backbone.classifier = torch.nn.Sequential(*list(backbone.classifier.children())[:-1], |
|
torch.nn.Conv2d(512, backbone_params["features"], kernel_size=(1, 1), |
|
stride=(1, 1))) |
|
elif backbone_params["name"] == "fcn101": |
|
backbone = torchvision.models.segmentation.fcn_resnet101(pretrained=backbone_params["pretrained"], |
|
num_classes=21) |
|
backbone.classifier = torch.nn.Sequential(*list(backbone.classifier.children())[:-1], |
|
torch.nn.Conv2d(512, backbone_params["features"], kernel_size=(1, 1), |
|
stride=(1, 1))) |
|
|
|
elif backbone_params["name"] == "deeplab50": |
|
backbone = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=backbone_params["pretrained"], |
|
num_classes=21) |
|
backbone.classifier = torch.nn.Sequential(*list(backbone.classifier.children())[:-1], |
|
torch.nn.Conv2d(256, backbone_params["features"], kernel_size=(1, 1), |
|
stride=(1, 1))) |
|
elif backbone_params["name"] == "deeplab101": |
|
backbone = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=backbone_params["pretrained"], |
|
num_classes=21) |
|
backbone.classifier = torch.nn.Sequential(*list(backbone.classifier.children())[:-1], |
|
torch.nn.Conv2d(256, backbone_params["features"], kernel_size=(1, 1), |
|
stride=(1, 1))) |
|
elif backbone_params["name"] == "unet_resnet": |
|
from torchvision.models.segmentation._utils import _SimpleSegmentationModel |
|
from frame_field_learning.unet_resnet import UNetResNetBackbone |
|
|
|
backbone = UNetResNetBackbone(backbone_params["encoder_depth"], num_filters=backbone_params["num_filters"], |
|
dropout_2d=backbone_params["dropout_2d"], |
|
pretrained=backbone_params["pretrained"], |
|
is_deconv=backbone_params["is_deconv"]) |
|
backbone = _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity()) |
|
|
|
elif backbone_params["name"] == "ictnet": |
|
from torchvision.models.segmentation._utils import _SimpleSegmentationModel |
|
from frame_field_learning.ictnet import ICTNetBackbone |
|
|
|
backbone = ICTNetBackbone(in_channels=backbone_params["in_channels"], |
|
out_channels=backbone_params["out_channels"], |
|
preset_model=backbone_params["preset_model"], |
|
dropout_2d=backbone_params["dropout_2d"], |
|
efficient=backbone_params["efficient"]) |
|
backbone = _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity()) |
|
else: |
|
print_utils.print_error("ERROR: config[\"backbone_params\"][\"name\"] = \"{}\" is an unknown backbone!" |
|
"If it is a new backbone you want to use, " |
|
"add it in backbone.py's get_backbone() function.".format(backbone_params["name"])) |
|
raise RuntimeError("Specified backbone {} unknown".format(backbone_params["name"])) |
|
return backbone |
|
|
|
|
|
def set_download_dir(): |
|
os.environ['TORCH_HOME'] = 'models' |
|
|