# Spaces:
# Sleeping
# Sleeping
# -*- coding: utf-8 -*-
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from models.modules import * | |
#from modules import * | |
# Registry of supported backbones.
# Each entry maps a backbone name to:
#   'models': the constructor, which Model calls as constructor(pretrained=...)
#   'out':    a list that Model hands to the segmentation-head constructor
#             (presumably the channel count of each backbone stage; TODO confirm).
# list(channels) gives every entry its own fresh list.
backbone_dict = {
    name: {'models': constructor, 'out': list(channels)}
    for name, constructor, channels in (
        ('resnet18', resnet18, (64, 128, 256, 512)),
        ('resnet34', resnet34, (64, 128, 256, 512)),
        ('resnet50', resnet50, (256, 512, 1024, 2048)),
        ('resnet101', resnet101, (256, 512, 1024, 2048)),
        ('resnet152', resnet152, (256, 512, 1024, 2048)),
        ('resnext50_32x4d', resnext50_32x4d, (256, 512, 1024, 2048)),
        ('resnext101_32x8d', resnext101_32x8d, (256, 512, 1024, 2048)),
        ('shufflenetv2', shufflenet_v2_x1_0, (24, 116, 232, 464)),
        # NOTE(review): these values are identical to the disabled
        # MobileNetV3_Large entry below; confirm they match mobilenet_v2_x1_0.
        ('mobilenetv2', mobilenet_v2_x1_0, (24, 40, 160, 160)),
    )
}

# Registry of segmentation heads, selected by name via model_config['segmentation_head'].
segmentation_head_dict = dict(FPN=FPN, FPEM_FFM=FPEM_FFM)

# Disabled backbone entries, kept for reference:
# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
# 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}
class Model(nn.Module):
    """Network composed of a named backbone followed by a named segmentation head.

    The original author labels this model "PANnet". Both parts are looked up by
    name in the module-level ``backbone_dict`` and ``segmentation_head_dict``
    registries, and ``self.name`` is set to ``'<backbone>_<segmentation_head>'``.
    """

    def __init__(self, model_config: dict):
        """Build the backbone and segmentation head described by ``model_config``.

        :param model_config: model configuration. Keys read here:
            ``'backbone'``: a key of ``backbone_dict``;
            ``'segmentation_head'``: a key of ``segmentation_head_dict``;
            ``'pretrained'``: forwarded to the backbone constructor
            (optional, defaults to ``False``).
            The whole dict is also forwarded as keyword arguments to the
            segmentation-head constructor, so head-specific options
            (e.g. ``'fpem_repeat'``) can be placed here too.
        :raises KeyError: if ``'backbone'`` or ``'segmentation_head'`` is missing.
        :raises ValueError: if either name is not a registered choice.
        """
        super().__init__()
        backbone = model_config['backbone']
        segmentation_head = model_config['segmentation_head']
        pretrained = model_config.get('pretrained', False)

        # Explicit checks rather than `assert`, which `python -O` strips out.
        # Only the valid names are listed; the registries' values are
        # constructor objects whose reprs are not useful in an error message.
        if backbone not in backbone_dict:
            raise ValueError('backbone must be one of {}, got {!r}'.format(
                list(backbone_dict), backbone))
        if segmentation_head not in segmentation_head_dict:
            raise ValueError('segmentation_head must be one of {}, got {!r}'.format(
                list(segmentation_head_dict), segmentation_head))

        spec = backbone_dict[backbone]
        backbone_out = spec['out']
        self.backbone = spec['models'](pretrained=pretrained)
        # The head receives the backbone's registered 'out' list plus the full config.
        self.segmentation_head = segmentation_head_dict[segmentation_head](
            backbone_out, **model_config)
        self.name = '{}_{}'.format(backbone, segmentation_head)

    def forward(self, x):
        """Run the backbone on ``x``, then the segmentation head on its output.

        :param x: 4-D input tensor, presumably an (N, C, H, W) image batch
            (TODO confirm layout).
        :return: whatever the segmentation head returns.
        :raises ValueError: if ``x`` is not 4-D.
        """
        # Only the rank is validated; the spatial size is not needed here.
        if x.dim() != 4:
            raise ValueError('expected a 4-D input tensor, got shape {}'.format(
                tuple(x.shape)))
        features = self.backbone(x)
        return self.segmentation_head(features)
if __name__ == '__main__':
    # Smoke test: build a model from a sample config and run one forward pass.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # Allocate directly on the target device instead of creating on CPU and copying.
    x = torch.zeros(1, 3, 640, 640, device=device)
    model_config = {
        'backbone': 'mobilenetv2',
        'fpem_repeat': 2,  # number of times the FPEM module is repeated
        'pretrained': False,  # whether the backbone uses an ImageNet-pretrained model
        'segmentation_head': 'FPN'  # segmentation head: FPN or FPEM_FFM
    }
    model = Model(model_config=model_config).to(device)
    # No gradients are needed for a smoke test, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    print(model)
    # The head's return type is not fixed here; report the shape when it is a tensor.
    print(y.shape if isinstance(y, torch.Tensor) else type(y))
    # torch.save(model.state_dict(), 'PAN.pth')