|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
|
|
from classes.fc4.squeezenet.Fire import Fire |
|
|
|
""" |
|
This is the standard SqueezeNet implementation from torchvision (PyTorch's vision library), available at:
|
https://github.com/pytorch/vision/blob/072d8b2280569a2d13b91d3ed51546d201a57366/torchvision/models/squeezenet.py |
|
|
|
SqueezeNet 1.0 |
|
-> Architecture from https://arxiv.org/abs/1602.07360, 'SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size'. |
|
|
|
SqueezeNet 1.1 (has 2.4x less computation and slightly fewer parameters than 1.0, without sacrificing accuracy) |
|
-> Architecture from https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1.
|
""" |
|
|
|
|
|
class SqueezeNet(nn.Module):
    """SqueezeNet 1.0 / 1.1 network for 3-channel image input.

    The model is made of two sub-modules:

    * ``features``: the convolutional trunk (stem ``Conv2d`` + ReLU,
      max-pools and ``Fire`` blocks). Its output is consumed by a
      ``Conv2d`` with 512 input channels, so it is expected to produce
      512 channels.
    * ``classifier``: dropout (p=0.5), a 1x1 ``Conv2d`` mapping 512 channels
      to ``num_classes``, ReLU, and a global average pool to 1x1.

    Args:
        version: architecture variant; ``1.0`` or ``1.1``.
        num_classes: number of output channels of the final 1x1 convolution,
            i.e. the second dimension of the tensor returned by ``forward``.

    Raises:
        ValueError: if ``version`` is neither ``1.0`` nor ``1.1``.
    """

    def __init__(self, version: float = 1.0, num_classes: int = 1000) -> None:
        super().__init__()

        self.num_classes = num_classes

        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == 1.1:
            # Smaller 3x3 stem and earlier pooling than 1.0.
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError("Unsupported SqueezeNet version {version}: 1.0 or 1.1 expected".format(version=version))

        # Keep a handle on the final 1x1 conv so the init loop below can
        # single it out for a different weight initialisation.
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # self.modules() recurses, so this visits every nn.Conv2d in the
        # network, including any nested inside the Fire blocks.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    # Small-std Gaussian for the output layer.
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the trunk and the classifier head.

        Args:
            x: input batch fed to ``self.features``; its first layer is a
                ``Conv2d`` with 3 input channels. Presumably shaped
                (N, 3, H, W); confirm against callers.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = self.features(x)
        x = self.classifier(x)
        # AdaptiveAvgPool2d((1, 1)) leaves (N, num_classes, 1, 1); drop the
        # trailing singleton spatial dims.
        return x.view(x.size(0), self.num_classes)
|
|