ameerazam08's picture
Upload folder using huggingface_hub
e34aada verified
raw
history blame
2.44 kB
import timm
import functools
import torch.utils.model_zoo as model_zoo
from .resnet import resnet_encoders
# Registry of available encoders, keyed by encoder name. The functions below
# read each entry through its "encoder" (callable that builds the encoder),
# "params" (keyword arguments for that callable) and "pretrained_settings"
# (weights name -> settings dict) keys. Built as a shallow copy so that
# registering extra encoders here does not modify ``resnet_encoders``.
encoders = dict(resnet_encoders)
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder from the module-level ``encoders`` registry.

    Args:
        name: Registry key of the encoder (see ``get_encoder_names()``).
        in_channels: Number of input channels; forwarded to the encoder's
            ``set_in_channels``.
        depth: Encoder depth. It overrides the registry's ``depth`` parameter for
            this call only; the registry itself is left unchanged.
        weights: Key into the entry's ``pretrained_settings``. That entry's ``url``
            is downloaded and loaded as the state dict. ``None`` skips loading.
        output_stride: When different from 32, ``make_dilated(output_stride)``
            is called on the built encoder.
        **kwargs: Accepted for interface compatibility; currently ignored.

    Returns:
        The constructed (and optionally weight-loaded) encoder instance.

    Raises:
        KeyError: If ``name`` is not registered, or ``weights`` is not one of the
            entry's pretrained options.
    """
    try:
        entry = encoders[name]
    except KeyError:
        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) from None

    Encoder = entry["encoder"]
    # Build a per-call copy. Updating entry["params"] in place would write this
    # call's depth into the shared registry and leak it into later calls.
    params = {**entry["params"], "depth": depth}
    encoder = Encoder(**params)

    if weights is not None:
        try:
            settings = entry["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights,
                    name,
                    list(entry["pretrained_settings"].keys()),
                )
            ) from None
        encoder.load_state_dict(model_zoo.load_url(settings["url"]))

    # Called after weights are loaded so the encoder knows whether it is adapting
    # pretrained input-layer weights.
    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)
    return encoder
def get_encoder_names():
    """Return a new list with every registered encoder name, in registry order."""
    return [*encoders]
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input-normalization settings matching an encoder's pretrained weights.

    Names prefixed with ``tu-`` are resolved through timm's pretrained-config
    registry, after the prefix is stripped. All other names are looked up in
    the module-level ``encoders`` registry under ``pretrained_settings[pretrained]``.

    Args:
        encoder_name: Registry key, or ``"tu-" + <timm model name>``.
        pretrained: Pretrained-weights key. Only used for non-``tu-`` encoders.

    Returns:
        A dict with keys ``input_space`` (default ``"RGB"``), ``input_range``
        (list, default ``[0, 1]``), ``mean`` (list) and ``std`` (list).

    Raises:
        ValueError: If a ``tu-`` model has no pretrained weights, or ``pretrained``
            is not an available option for a registered encoder.
        KeyError: If a non-``tu-`` ``encoder_name`` is not registered.
    """
    if encoder_name.startswith("tu-"):
        encoder_name = encoder_name[3:]
        if not timm.models.is_model_pretrained(encoder_name):
            raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters")
        cfg = timm.models.get_pretrained_cfg(encoder_name)
        # Older timm returns a plain dict. Newer timm returns a ``PretrainedCfg``
        # dataclass, which has no ``.get``, so read its fields as a dict.
        settings = cfg if isinstance(cfg, dict) else vars(cfg)
    else:
        if encoder_name not in encoders:
            raise KeyError(
                "Wrong encoder name `{}`, supported encoders: {}".format(encoder_name, list(encoders.keys()))
            )
        all_settings = encoders[encoder_name]["pretrained_settings"]
        if pretrained not in all_settings:
            raise ValueError("Available pretrained options {}".format(all_settings.keys()))
        settings = all_settings[pretrained]

    # Normalize sequences to lists so callers get the same types whichever
    # source (timm tuple fields or registry entries) supplied them.
    return {
        "input_space": settings.get("input_space", "RGB"),
        "input_range": list(settings.get("input_range", [0, 1])),
        "mean": list(settings.get("mean")),
        "std": list(settings.get("std")),
    }
def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with this encoder's normalization settings pre-bound.

    The bound keyword arguments are exactly the dict produced by
    ``get_preprocessing_params(encoder_name, pretrained=pretrained)``.
    """
    # NOTE(review): `preprocess_input` is neither defined nor imported in this
    # file as shown, so calling this function would raise NameError unless it
    # is brought into scope elsewhere. Confirm the intended import.
    normalization = get_preprocessing_params(encoder_name, pretrained=pretrained)
    bind = functools.partial
    return bind(preprocess_input, **normalization)