File size: 805 Bytes
a6cbc5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from segmentation_models_pytorch.encoders import encoders
from segmentation_models_pytorch import Unet
import torch

# Register MicroNet-pretrained EfficientNet-B7 weights as an extra entry in
# segmentation_models_pytorch's encoder registry, so that
# Unet(encoder_weights="imagenet-micronet") can resolve them by name.
# NOTE(review): this mutates the library's global `encoders` dict for the
# whole process. It also assumes the installed smp version downloads weights
# from settings["url"]; some newer smp releases look up HF-Hub
# "repo_id"/"revision" keys instead — confirm against the installed version.
encoders["efficientnet-b7"]["pretrained_settings"]["imagenet-micronet"] = {
    "url": "https://huggingface.co/jstuckner/microscopy-efficientnet-b7-imagenet-micronet/resolve/main/efficientnet-b7_imagenet-micronet_weights.pth",
    # Preprocessing metadata (consumed by smp's preprocessing helpers):
    # RGB input scaled to [0, 1], then normalized with the ImageNet
    # per-channel mean/std values below.
    "input_space": "RGB",
    "input_range": [0, 1],
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Build a U-Net with a single output channel on top of the registered
# encoder weights; activation=None leaves the head output un-activated
# (raw logits).
model = Unet(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet-micronet",
    classes=1,
    activation=None,
)
# Switch to inference behaviour. torch.no_grad() alone does not do this:
# without eval(), BatchNorm normalizes with per-batch statistics and keeps
# updating its running averages, and dropout stays active, so the smoke test
# below would neither reflect real inference nor leave the model unchanged.
model.eval()

# Smoke test on a random batch. 256x256 is divisible by 32, which smp's Unet
# expects of the spatial dimensions with its default encoder depth of 5.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
# With classes=1 this should print (1, 1, 256, 256): one logit map per image.
print("Output shape:", y.shape)