# Commit 0742dfe by jmliu: "Add app"
# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F
from models.modules import *
#from modules import *
# Backbone registry: config name -> {'models': constructor, 'out': channel list}.
# Model calls the constructor as ``constructor(pretrained=...)`` and passes 'out'
# as the first positional argument of the segmentation head; presumably it lists
# the channel count of each feature level the backbone returns (confirm against
# models.modules).
backbone_dict = dict(
    resnet18=dict(models=resnet18, out=[64, 128, 256, 512]),
    resnet34=dict(models=resnet34, out=[64, 128, 256, 512]),
    resnet50=dict(models=resnet50, out=[256, 512, 1024, 2048]),
    resnet101=dict(models=resnet101, out=[256, 512, 1024, 2048]),
    resnet152=dict(models=resnet152, out=[256, 512, 1024, 2048]),
    resnext50_32x4d=dict(models=resnext50_32x4d, out=[256, 512, 1024, 2048]),
    resnext101_32x8d=dict(models=resnext101_32x8d, out=[256, 512, 1024, 2048]),
    shufflenetv2=dict(models=shufflenet_v2_x1_0, out=[24, 116, 232, 464]),
    # NOTE(review): these channels are identical to the disabled MobileNetV3_Large
    # entry below; confirm they really are mobilenet_v2_x1_0's output channels.
    mobilenetv2=dict(models=mobilenet_v2_x1_0, out=[24, 40, 160, 160]),
)

# Segmentation-head registry: config name -> head class.
segmentation_head_dict = dict(FPN=FPN, FPEM_FFM=FPEM_FFM)

# Disabled backbone entries, kept for reference only (not registered above):
# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
class Model(nn.Module):
    def __init__(self, model_config: dict):
        """
        PANnet: a backbone network followed by a segmentation head.

        :param model_config: model configuration. Keys read here:
            - 'backbone': a key of ``backbone_dict``
            - 'pretrained': forwarded to the backbone constructor as ``pretrained=``
            - 'segmentation_head': a key of ``segmentation_head_dict``
            The whole dict is also forwarded to the segmentation head as keyword
            arguments (e.g. 'fpem_repeat' in the __main__ example), so the head
            presumably accepts or ignores extra keys -- confirm in models.modules.
        :raises KeyError: if one of the three keys above is missing.
        :raises ValueError: if 'backbone' or 'segmentation_head' is not registered.
        """
        super().__init__()
        backbone = model_config['backbone']
        pretrained = model_config['pretrained']
        segmentation_head = model_config['segmentation_head']
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would turn a bad name into an opaque KeyError below.
        if backbone not in backbone_dict:
            raise ValueError('backbone must be one of {}, got {!r}'.format(
                list(backbone_dict), backbone))
        if segmentation_head not in segmentation_head_dict:
            raise ValueError('segmentation_head must be one of {}, got {!r}'.format(
                list(segmentation_head_dict), segmentation_head))
        backbone_spec = backbone_dict[backbone]
        self.backbone = backbone_spec['models'](pretrained=pretrained)
        # The head is built from the backbone's 'out' channel list
        # (presumably one entry per returned feature level).
        self.segmentation_head = segmentation_head_dict[segmentation_head](
            backbone_spec['out'], **model_config)
        self.name = '{}_{}'.format(backbone, segmentation_head)

    def forward(self, x):
        """
        Run the backbone, then the segmentation head on its output.

        :param x: input batch; must be a 4-D tensor (presumably N, C, H, W -- confirm).
        :return: whatever the segmentation head returns.
        :raises ValueError: if ``x`` is not 4-D.
        """
        # Only 4-D batches are accepted; fail early with a readable message.
        if x.dim() != 4:
            raise ValueError('expected a 4-D input tensor, got shape {}'.format(
                tuple(x.shape)))
        return self.segmentation_head(self.backbone(x))
if __name__ == '__main__':
    # Smoke test: build a model from a sample config and run one dummy forward pass.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    x = torch.zeros(1, 3, 640, 640).to(device)
    model_config = {
        'backbone': 'mobilenetv2',
        'fpem_repeat': 2,  # number of times the FPEM module is repeated
        'pretrained': False,  # whether the backbone loads ImageNet-pretrained weights
        'segmentation_head': 'FPN'  # segmentation head: FPN or FPEM_FFM
    }
    model = Model(model_config=model_config).to(device)
    # Inference mode for the smoke test: eval() switches any BatchNorm/Dropout
    # layers to inference behaviour, and no_grad() avoids building an autograd
    # graph for an output that is only inspected.
    model.eval()
    with torch.no_grad():
        y = model(x)
    print(model)
    # The head's return type is not fixed here; report a shape only for a tensor.
    print(y.shape if torch.is_tensor(y) else type(y))
    # torch.save(model.state_dict(), 'PAN.pth')