from functools import partial
import torch
# NOTE: these model entrypoints exist in older timm releases (<= 0.6.x); newer
# timm versions expose the noisy-student weights through
# timm.create_model("tf_efficientnet_b3.ns_jft_in1k") instead.
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import AdaptiveAvgPool2d, Conv2d, Dropout, Dropout2d, Linear, UpsamplingBilinear2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "filters": [40, 32, 48, 136, 1536],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "filters": [48, 40, 64, 176, 2048],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
}
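
# "filters" lists the channels of the five skip features tapped in
# get_encoder_features below, ordered shallow -> deep; for tf_efficientnet_b3_ns
# the (channels, stride) pairs are (40, 2), (32, 4), (48, 8), (136, 16),
# (1536, 32). "decoder_filters" are the output channels of the four decoder
# stages, and "features" is the pooled classifier input width.
#
# A sketch of registering another backbone. tf_efficientnet_b4_ns is an
# assumption (not imported above) and the channel counts should be verified
# against the actual timm model before use:
#
#   encoder_params["tf_efficientnet_b4_ns"] = {
#       "features": 1792,
#       "filters": [48, 32, 56, 160, 1792],
#       "decoder_filters": [64, 128, 256, 256],
#       "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.2),
#   }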

class DecoderBlock(nn.Module):
    # One decoder stage: nearest-neighbor 2x upsample, 3x3 conv, ReLU.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)

class ConcatBottleneck(nn.Module):
    # Fuses a decoder feature with the matching encoder skip: channel-wise
    # concatenation followed by a 3x3 conv + ReLU.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)
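
# For example, with the b3 config the deepest fusion is
# ConcatBottleneck(136 + 256, 256): it concatenates a (B, 256, H, W) decoder
# feature with the (B, 136, H, W) stride-16 skip and maps the result back to
# (B, 256, H, W).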

class Decoder(nn.Module):
    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(len(decoder_filters))])
        # Each bottleneck fuses a decoder output (f channels) with the encoder
        # skip one level shallower (self.filters[-i - 2] channels).
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        # encoder_results is ordered deepest first; start from the bottleneck feature.
        x = encoder_results[0]
        for idx, bottleneck in enumerate(self.bottlenecks):
            # decoder_stages is built shallow -> deep, so walk it from the end;
            # -rev_idx == idx + 1 selects the next-shallower encoder skip.
            rev_idx = -(idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, encoder_results[-rev_idx])
        if self.last_block:
            x = self.last_block(x)
        if self.dropout:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        # Input channels come from the next-deeper decoder stage, except for the
        # deepest stage, which takes the encoder bottleneck feature directly.
        idx = layer + 1
        if idx == len(self.decoder_filters):
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[layer])
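
# A minimal standalone sketch of the decoder (b3 sizes assumed; the dummy
# feature shapes follow the (channels, stride) taps noted above):
#
#   dec = Decoder(decoder_filters=[64, 128, 256, 256],
#                 filters=[40, 32, 48, 136, 1536])
#   feats = [torch.rand(1, c, 224 // s, 224 // s)
#            for c, s in [(1536, 32), (136, 16), (48, 8), (32, 4), (40, 2)]]
#   out = dec(feats)  # -> (1, 64, 224, 224) after the final bilinear 2x upsample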

def _initialize_weights(module):
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight.data)  # in-place initialization
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

class EfficientUnetClassifier(nn.Module):
    # U-Net with an EfficientNet encoder and two heads: a single-logit
    # classifier on the pooled bottleneck feature and a 1-channel mask.
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
                               filters=encoder_params[encoder]["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)
        self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        # Initialize decoder/head weights first, then create the encoder so its
        # pretrained weights are not overwritten.
        _initialize_weights(self)
        self.encoder = encoder_params[encoder]["init_op"]()

    def get_encoder_features(self, x):
        # Tap five feature maps from the timm EfficientNet, matching
        # encoder_params[...]["filters"] (channels at strides 2/4/8/16/32).
        encoder_results = []
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        x = self.encoder.act1(x)
        encoder_results.append(x)  # stride 2
        x = self.encoder.blocks[:2](x)
        encoder_results.append(x)  # stride 4
        x = self.encoder.blocks[2:3](x)
        encoder_results.append(x)  # stride 8
        x = self.encoder.blocks[3:5](x)
        encoder_results.append(x)  # stride 16
        x = self.encoder.blocks[5:](x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.act2(x)
        encoder_results.append(x)  # stride 32 (bottleneck)
        # Deepest feature first, as Decoder.forward expects.
        return list(reversed(encoder_results))

    def forward(self, x):
        encoder_results = self.get_encoder_features(x)
        # Segmentation head: decode back to input resolution, then 1x1 conv.
        seg = self.final(self.decoder(encoder_results))
        # Classification head: global-average-pool the bottleneck feature.
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg

if __name__ == '__main__':
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        batch = torch.rand(4, 3, 224, 224)  # avoid shadowing the builtin `input`
        print(model(batch))
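        # Hedged sanity check (assumes the 4x3x224x224 batch above): the model
        # should return one logit per image and a full-resolution 1-channel mask.
        cls_logits, seg_mask = model(batch)
        assert cls_logits.shape == (4, 1)
        assert seg_mask.shape == (4, 1, 224, 224)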