import torch.nn as nn | |
from .encoder import resnet34 | |
from .decoder import DeepLabV3Decoder | |
class DeepLabV3(nn.Module):
    """Segmentation model composed of a ``resnet34`` encoder and a ``DeepLabV3Decoder``.

    Args:
        input_channels: Number of channels of the input; forwarded unchanged to
            ``resnet34(input_channels=...)``.
        decoder_in_channels: Value passed as ``in_channels`` to
            ``DeepLabV3Decoder``. Defaults to 128, the value that was
            previously hard-coded, so existing callers are unaffected.
            NOTE(review): presumably this must match the channel count of the
            encoder feature map(s) the decoder consumes — confirm against the
            encoder's outputs when swapping encoders.
    """

    def __init__(self, input_channels: int, decoder_in_channels: int = 128):
        super().__init__()
        self.encoder = resnet34(input_channels=input_channels)
        # Exposed as a parameter instead of a literal so a different encoder
        # (or encoder stage) can be paired with the decoder.
        self.decoder = DeepLabV3Decoder(in_channels=decoder_in_channels)

    def forward(self, x):
        """Encode ``x`` and decode the resulting features.

        The encoder's return value is unpacked positionally into the decoder,
        so it is assumed to be a sequence (e.g. tuple/list) of feature maps —
        TODO confirm the encoder's return type.

        Args:
            x: Model input passed directly to the encoder.

        Returns:
            Whatever ``self.decoder`` returns for the unpacked features.
        """
        features = self.encoder(x)
        return self.decoder(*features)