|
import os |
|
|
|
from torch.utils import model_zoo |
|
|
|
from classes.fc4.squeezenet.SqueezeNet import SqueezeNet |
|
|
|
# URLs of the SqueezeNet weight files hosted on download.pytorch.org,
# keyed by the SqueezeNet version number they belong to.
model_urls = {
    1.0: "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth",
    1.1: "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
}
|
|
|
|
|
class SqueezeNetLoader:
    """
    Builds a SqueezeNet of the requested version and, on demand, fills it with
    the weights downloaded from the URL registered for that version.
    """

    def __init__(self, version: float = 1.1):
        """
        @param version: SqueezeNet version to build; also used as the key into
                        `model_urls` when pre-trained weights are requested
        """
        self.__version = version
        self.__model = SqueezeNet(self.__version)

    def load(self, pretrained: bool = False) -> SqueezeNet:
        """
        Returns the SqueezeNet instance created for the configured version
        @param pretrained: if True, the returned model carries weights pre-trained on ImageNet
        """
        if not pretrained:
            return self.__model

        # Redirect the torch cache so the weight file is stored under the
        # project's assets folder. NOTE: this changes TORCH_HOME for the whole
        # process, not only for this call.
        os.environ['TORCH_HOME'] = os.path.join("assets", "pretrained")

        # An unknown version raises KeyError here.
        weights_url = model_urls[self.__version]
        state_dict = model_zoo.load_url(weights_url)
        self.__model.load_state_dict(state_dict)

        return self.__model
|
|