# Artyom
# MiAlgo
# 82567db verified
import torch
import torch.nn as nn
import torch.nn.init as init
from classes.fc4.squeezenet.Fire import Fire
"""
This is the standard SqueezeNet implementation included in PyTorch at:
https://github.com/pytorch/vision/blob/072d8b2280569a2d13b91d3ed51546d201a57366/torchvision/models/squeezenet.py
SqueezeNet 1.0
-> Architecture from https://arxiv.org/abs/1602.07360, 'SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size'.
SqueezeNet 1.1 (has 2.4x less computation and slightly fewer parameters than 1.0, without sacrificing accuracy)
-> Architecture from: https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1
"""
class SqueezeNet(nn.Module):
    """SqueezeNet 1.0 / 1.1 feature extractor followed by a 1x1-conv classifier head.

    Args:
        version: architecture variant, ``1.0`` or ``1.1``. It is compared with
            ``==``, so numerically equal values (e.g. ``1``) select the same variant.
        num_classes: number of output channels of the final 1x1 convolution,
            i.e. the size of the second dimension of ``forward``'s output.

    Raises:
        ValueError: if ``version`` is neither 1.0 nor 1.1.
    """

    def __init__(self, version: float = 1.0, num_classes: int = 1000):
        super().__init__()
        self.num_classes = num_classes

        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == 1.1:
            # 1.1: smaller 3x3 stem and earlier pooling than 1.0.
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1.0 or 1.1 expected")

        # Final convolution is initialized differently from the rest (see init loop below).
        # Both variants end with 512 feature channels, hence the fixed input width.
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # self.modules() walks the whole tree, so this also reaches any nn.Conv2d
        # nested inside the Fire modules (presumably Fire is built from Conv2d — verify).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feature extractor and classifier head.

        Args:
            x: input batch; the first convolution expects 3 input channels, so
                presumably shape (batch, 3, H, W) — confirm against callers.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        x = self.features(x)
        x = self.classifier(x)
        # AdaptiveAvgPool2d((1, 1)) leaves (batch, num_classes, 1, 1); drop the spatial dims.
        return x.view(x.size(0), self.num_classes)