File size: 4,501 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import torch
import torchvision
from lydorn_utils import print_utils
def _wrap_custom_backbone(backbone):
    """Wrap a project-defined backbone in torchvision's ``_SimpleSegmentationModel`` with an identity classifier.

    NOTE(review): ``_SimpleSegmentationModel`` is a private torchvision API and may
    change between torchvision releases.
    """
    from torchvision.models.segmentation._utils import _SimpleSegmentationModel
    return _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity())


def _torchvision_backbone(constructor, head_channels, backbone_params):
    """Build a torchvision segmentation model and swap its last classifier layer for a 1x1 conv.

    Args:
        constructor: torchvision segmentation model factory (e.g. ``fcn_resnet50``).
        head_channels: number of channels entering the classifier's last layer
            (512 for the FCN heads, 256 for the DeepLabV3 heads, as used by this file).
        backbone_params: config dict; reads ``"pretrained"`` and ``"features"``.

    Returns:
        The torchvision model with ``classifier`` rebuilt so that its final layer is
        ``Conv2d(head_channels, backbone_params["features"], kernel_size=1)``.
    """
    # num_classes=21 is kept from the original code; presumably it matches the class
    # count of torchvision's pretrained segmentation weights so they can be loaded
    # before the last layer is replaced.
    backbone = constructor(pretrained=backbone_params["pretrained"], num_classes=21)
    head_layers = list(backbone.classifier.children())
    backbone.classifier = torch.nn.Sequential(
        *head_layers[:-1],  # keep everything except the original class-score conv
        torch.nn.Conv2d(head_channels, backbone_params["features"], kernel_size=(1, 1), stride=(1, 1)),
    )
    return backbone


def get_backbone(backbone_params):
    """Instantiate the backbone network described by ``backbone_params``.

    Calls ``set_download_dir()`` first, which sets the ``TORCH_HOME`` environment
    variable for the current process.

    Args:
        backbone_params: config dict. ``"name"`` selects the architecture; the other
            keys read depend on it:

            - ``"unet"``: ``"input_features"``, ``"features"``
            - ``"fcn50"``, ``"fcn101"``, ``"deeplab50"``, ``"deeplab101"``:
              ``"pretrained"``, ``"features"``
            - ``"unet_resnet"``: ``"encoder_depth"``, ``"num_filters"``,
              ``"dropout_2d"``, ``"pretrained"``, ``"is_deconv"``
            - ``"ictnet"``: ``"in_channels"``, ``"out_channels"``, ``"preset_model"``,
              ``"dropout_2d"``, ``"efficient"``

    Returns:
        A ``torch.nn.Module``. For the torchvision models it is the torchvision model
        with its last classifier layer replaced. For the project backbones (unet,
        unet_resnet, ictnet) it is the backbone wrapped in ``_SimpleSegmentationModel``
        with an identity classifier.

    Raises:
        KeyError: if ``"name"`` or a key required by the selected backbone is missing.
        RuntimeError: if ``"name"`` is not a known backbone (an error is printed first).
    """
    set_download_dir()
    name = backbone_params["name"]

    # torchvision models that differ only in constructor and classifier head width.
    torchvision_backbones = {
        "fcn50": (torchvision.models.segmentation.fcn_resnet50, 512),
        "fcn101": (torchvision.models.segmentation.fcn_resnet101, 512),
        "deeplab50": (torchvision.models.segmentation.deeplabv3_resnet50, 256),
        "deeplab101": (torchvision.models.segmentation.deeplabv3_resnet101, 256),
    }
    if name in torchvision_backbones:
        constructor, head_channels = torchvision_backbones[name]
        return _torchvision_backbone(constructor, head_channels, backbone_params)

    # Project backbones are imported lazily so only the selected one is loaded.
    if name == "unet":
        from frame_field_learning.unet import UNetBackbone
        backbone = UNetBackbone(backbone_params["input_features"], backbone_params["features"])
        return _wrap_custom_backbone(backbone)
    if name == "unet_resnet":
        from frame_field_learning.unet_resnet import UNetResNetBackbone
        backbone = UNetResNetBackbone(backbone_params["encoder_depth"],
                                      num_filters=backbone_params["num_filters"],
                                      dropout_2d=backbone_params["dropout_2d"],
                                      pretrained=backbone_params["pretrained"],
                                      is_deconv=backbone_params["is_deconv"])
        return _wrap_custom_backbone(backbone)
    if name == "ictnet":
        from frame_field_learning.ictnet import ICTNetBackbone
        backbone = ICTNetBackbone(in_channels=backbone_params["in_channels"],
                                  out_channels=backbone_params["out_channels"],
                                  preset_model=backbone_params["preset_model"],
                                  dropout_2d=backbone_params["dropout_2d"],
                                  efficient=backbone_params["efficient"])
        return _wrap_custom_backbone(backbone)

    print_utils.print_error("ERROR: config[\"backbone_params\"][\"name\"] = \"{}\" is an unknown backbone! "
                            "If it is a new backbone you want to use, "
                            "add it in backbone.py's get_backbone() function.".format(name))
    raise RuntimeError("Specified backbone {} unknown".format(name))
def set_download_dir():
    """Set the ``TORCH_HOME`` environment variable to ``"models"`` for this process.

    ``"models"`` is a relative path, so it resolves against the current working
    directory. Presumably this makes torch download and cache pretrained weights
    under ``./models``. Confirm against the ``torch.hub`` documentation for
    ``TORCH_HOME``.
    """
    torch_home = "models"
    os.environ.update({"TORCH_HOME": torch_home})
|