from functools import partial

import torch
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import Dropout2d, Conv2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
from torch.nn.modules.upsampling import UpsamplingBilinear2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "filters": [40, 32, 48, 136, 1536],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "filters": [48, 40, 64, 176, 2048],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
}


class DecoderBlock(nn.Module):
    """Upsample by 2x, then apply a 3x3 convolution with ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ConcatBottleneck(nn.Module):
    """Concatenate a decoder feature map with the matching encoder skip
    connection along the channel axis, then fuse with a 3x3 conv + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)


class Decoder(nn.Module):
    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        # One upsampling stage per decoder filter; forward() applies them
        # deepest-first via negative indexing.
        self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
        # Each bottleneck fuses a decoder output with the corresponding encoder
        # skip connection, hence the summed input channel count.
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        self.last_block = None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        # encoder_results is ordered deepest feature map first.
        x = encoder_results[0]
        bottlenecks = self.bottlenecks
        for idx, bottleneck in enumerate(bottlenecks):
            rev_idx = - (idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, encoder_results[-rev_idx])
        if self.last_block:
            x = self.last_block(x)
        if self.dropout:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        # The deepest stage consumes the encoder head; every other stage
        # consumes the output of the preceding (deeper) bottleneck, which has
        # decoder_filters[layer + 1] channels.
        idx = layer + 1
        if idx == len(self.decoder_filters):
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
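

# Worked example of the Decoder wiring for the "tf_efficientnet_b5_ns" config
# (derived from encoder_params above; the deepest feature map is processed first):
#   decoder_stages, applied in reverse order:
#     2048 -> 256, 256 -> 256, 256 -> 128, 128 -> 64
#   bottlenecks, fusing each stage output with its encoder skip connection:
#     cat(256, 176) -> 256, cat(256, 64) -> 256, cat(128, 40) -> 128, cat(64, 48) -> 64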


def _initialize_weights(module):
    # Kaiming-normal init for conv/linear weights, constant init for batch norm.
    for m in module.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
            m.weight.data = nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EfficientUnetClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
                               filters=encoder_params[encoder]["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        # Classification head on the pooled encoder features and a 1x1 conv
        # segmentation head on the decoder output.
        self.fc = Linear(encoder_params[encoder]["features"], 1)
        self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        _initialize_weights(self)
        self.encoder = encoder_params[encoder]["init_op"]()
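        # _initialize_weights is called before the encoder is constructed, so
        # the pretrained timm weights are not overwritten.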

    def get_encoder_features(self, x):
        # Collect intermediate feature maps at five scales and return them
        # deepest first, as expected by Decoder.forward.
        encoder_results = []
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        x = self.encoder.act1(x)
        encoder_results.append(x)
        x = self.encoder.blocks[:2](x)
        encoder_results.append(x)
        x = self.encoder.blocks[2:3](x)
        encoder_results.append(x)
        x = self.encoder.blocks[3:5](x)
        encoder_results.append(x)
        x = self.encoder.blocks[5:](x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.act2(x)
        encoder_results.append(x)
        encoder_results = list(reversed(encoder_results))
        return encoder_results

    def forward(self, x):
        encoder_results = self.get_encoder_features(x)
        # Segmentation logits from the decoder, classification logit from the
        # pooled deepest encoder features.
        seg = self.final(self.decoder(encoder_results))
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg


if __name__ == '__main__':
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        batch = torch.rand(4, 3, 224, 224)  # dummy RGB batch
        print(model(batch))
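        # Quick shape check (a sketch; the expected shapes below assume the
        # 224x224 dummy batch above): the classification head returns one
        # logit per image and the segmentation head a full-resolution,
        # single-channel map.
        logits, seg = model(batch)
        assert logits.shape == (4, 1)
        assert seg.shape == (4, 1, 224, 224)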