import torch | |
import torch.nn as nn | |
from mono.utils.comm import get_func | |
class DensePredModel(nn.Module): | |
def __init__(self, cfg): | |
super(DensePredModel, self).__init__() | |
self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone) | |
self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg) | |
# try: | |
# decoder_compiled = torch.compile(decoder, mode='max-autotune') | |
# "Decoder compile finished" | |
# self.decoder = decoder_compiled | |
# except: | |
# "Decoder compile failed, use default setting" | |
# self.decoder = decoder | |
self.training = True | |
def forward(self, input, **kwargs): | |
# [f_32, f_16, f_8, f_4] | |
features = self.encoder(input) | |
# [x_32, x_16, x_8, x_4, x, ...] | |
out = self.decoder(features, **kwargs) | |
return out |