File size: 4,501 Bytes
abd2a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os

import torch
import torchvision

from lydorn_utils import print_utils


# Torchvision segmentation models supported by get_backbone():
#   backbone name -> (constructor name in torchvision.models.segmentation,
#                     input channel count of the final classifier layer that gets replaced)
# Constructors are stored by name and resolved lazily so nothing is looked up at import time.
_TORCHVISION_SEGMENTATION_BACKBONES = {
    "fcn50": ("fcn_resnet50", 512),
    "fcn101": ("fcn_resnet101", 512),
    "deeplab50": ("deeplabv3_resnet50", 256),
    "deeplab101": ("deeplabv3_resnet101", 256),
}


def _build_torchvision_backbone(constructor_name, head_in_channels, backbone_params):
    """Build a torchvision segmentation model whose last classifier layer outputs
    backbone_params["features"] channels instead of class scores."""
    constructor = getattr(torchvision.models.segmentation, constructor_name)
    # NOTE(review): num_classes=21 presumably matches the shape of the pretrained checkpoint's
    # classifier so that its weights can be loaded - confirm against the torchvision version in use.
    backbone = constructor(pretrained=backbone_params["pretrained"], num_classes=21)
    # Keep every layer of the classifier head except the last one, and append a fresh 1x1 convolution
    # producing the requested number of feature channels.
    kept_layers = list(backbone.classifier.children())[:-1]
    backbone.classifier = torch.nn.Sequential(
        *kept_layers,
        torch.nn.Conv2d(head_in_channels, backbone_params["features"], kernel_size=(1, 1), stride=(1, 1)))
    return backbone


def get_backbone(backbone_params):
    """Instantiate the backbone network described by ``backbone_params``.

    Before building anything, set_download_dir() is called so that TORCH_HOME is set.

    Args:
        backbone_params: mapping whose "name" entry selects the architecture. The other entries read
            depend on the name:
            - "unet": "input_features", "features"
            - "fcn50", "fcn101", "deeplab50", "deeplab101": "pretrained", "features"
            - "unet_resnet": "encoder_depth", "num_filters", "dropout_2d", "pretrained", "is_deconv"
            - "ictnet": "in_channels", "out_channels", "preset_model", "dropout_2d", "efficient"

    Returns:
        The backbone module. Torchvision models are returned with their last classifier layer replaced;
        the project's own backbones (unet, unet_resnet, ictnet) are wrapped in torchvision's
        _SimpleSegmentationModel with an identity classifier.

    Raises:
        RuntimeError: if backbone_params["name"] is not one of the names listed above.
        KeyError: if an entry required for the selected backbone is missing from backbone_params.
    """
    set_download_dir()
    name = backbone_params["name"]

    if name in _TORCHVISION_SEGMENTATION_BACKBONES:
        constructor_name, head_in_channels = _TORCHVISION_SEGMENTATION_BACKBONES[name]
        return _build_torchvision_backbone(constructor_name, head_in_channels, backbone_params)

    # Project-specific backbones are imported lazily so that only the selected one has to be importable.
    if name == "unet":
        from torchvision.models.segmentation._utils import _SimpleSegmentationModel
        from frame_field_learning.unet import UNetBackbone

        backbone = UNetBackbone(backbone_params["input_features"], backbone_params["features"])
    elif name == "unet_resnet":
        from torchvision.models.segmentation._utils import _SimpleSegmentationModel
        from frame_field_learning.unet_resnet import UNetResNetBackbone

        backbone = UNetResNetBackbone(backbone_params["encoder_depth"], num_filters=backbone_params["num_filters"],
                                      dropout_2d=backbone_params["dropout_2d"],
                                      pretrained=backbone_params["pretrained"],
                                      is_deconv=backbone_params["is_deconv"])
    elif name == "ictnet":
        from torchvision.models.segmentation._utils import _SimpleSegmentationModel
        from frame_field_learning.ictnet import ICTNetBackbone

        backbone = ICTNetBackbone(in_channels=backbone_params["in_channels"],
                                  out_channels=backbone_params["out_channels"],
                                  preset_model=backbone_params["preset_model"],
                                  dropout_2d=backbone_params["dropout_2d"],
                                  efficient=backbone_params["efficient"])
    else:
        # The literals below are concatenated by the parser: the space after "backbone!" is required
        # to keep the two sentences apart in the printed message.
        print_utils.print_error("ERROR: config[\"backbone_params\"][\"name\"] = \"{}\" is an unknown backbone! "
                                "If it is a new backbone you want to use, "
                                "add it in backbone.py's get_backbone() function.".format(name))
        raise RuntimeError("Specified backbone {} unknown".format(name))

    # NOTE(review): the identity classifier passes the backbone output through unchanged; the wrapper
    # is presumably used so these backbones are called like the torchvision models - confirm.
    return _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity())


def set_download_dir():
    """Set the TORCH_HOME environment variable of the current process to 'models'.

    Any value TORCH_HOME already had is overwritten. 'models' is a relative path.
    NOTE(review): torch.hub is documented to use $TORCH_HOME as the root of its download cache,
    which is presumably why this is called before building pretrained backbones - confirm.
    """
    os.environ.update({'TORCH_HOME': 'models'})