Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
try: | |
import timm | |
except ImportError: | |
timm = None | |
from mmengine.model import BaseModule | |
from mmengine.registry import MODELS as MMENGINE_MODELS | |
from mmseg.registry import MODELS | |
class TIMMBackbone(BaseModule):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_ .

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Forwarded to ``timm.create_model`` as its
            ``features_only`` argument. Default: True.
        pretrained (bool): Load pretrained weights if True. Default: True.
        checkpoint_path (str): Path of checkpoint to load after
            model is initialized. Default: ''.
        in_channels (int): Number of input image channels. Default: 3.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
        **kwargs: Other timm & model specific arguments. A ``norm_layer``
            given as a string is looked up in the mmengine ``MODELS``
            registry; any other value is handed to timm unchanged.

    Raises:
        RuntimeError: If timm is not installed.
        KeyError: If ``norm_layer`` is a string that is not registered in
            the mmengine ``MODELS`` registry.
    """

    def __init__(
        self,
        model_name,
        features_only=True,
        pretrained=True,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        if timm is None:
            raise RuntimeError('timm is not installed')
        super().__init__(init_cfg)

        # Only string names go through the registry; an already-resolved
        # class/callable is forwarded to timm as-is.
        norm_layer = kwargs.get('norm_layer')
        if isinstance(norm_layer, str):
            resolved_norm = MMENGINE_MODELS.get(norm_layer)
            if resolved_norm is None:
                # Fail here with a clear message instead of handing
                # ``norm_layer=None`` to timm.
                raise KeyError(
                    f'norm_layer "{norm_layer}" is not registered in the '
                    'mmengine MODELS registry')
            kwargs['norm_layer'] = resolved_norm

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs,
        )

        # Make unused parameters None
        # NOTE(review): these are cleared regardless of ``features_only``;
        # confirm the wrapped model's forward does not need them when
        # ``features_only=False``.
        self.timm_model.global_pool = None
        self.timm_model.fc = None
        self.timm_model.classifier = None

        # Hack to use pretrained weights from timm: flag the module as
        # already initialized (presumably so a later ``init_weights()`` call
        # leaves the loaded weights alone -- TODO confirm against BaseModule).
        if pretrained or checkpoint_path:
            self._is_init = True

    def forward(self, x):
        """Run ``x`` through the wrapped timm model and return its output."""
        features = self.timm_model(x)
        return features