# zach
# initial commit based on github repo
# 3ef1661
# raw
# history blame contribute delete
# 966 Bytes
import torch
import torch.nn as nn
from mono.utils.comm import get_func
class DensePredModel(nn.Module):
    """Encoder/decoder wrapper used for dense prediction.

    Both sub-modules are looked up by dotted name through ``get_func`` and
    then instantiated from the configuration:

    * ``encoder`` -- resolved from ``'mono.model.' +
      cfg.model.backbone.prefix + cfg.model.backbone.type`` and called with
      ``**cfg.model.backbone``.
    * ``decoder`` -- resolved from ``'mono.model.' +
      cfg.model.decode_head.prefix + cfg.model.decode_head.type`` and
      called with the whole ``cfg``.
    """

    def __init__(self, cfg) -> None:
        """Build the encoder and decoder described by ``cfg``.

        Args:
            cfg: Configuration object exposing ``model.backbone`` and
                ``model.decode_head``, each providing ``prefix`` and
                ``type`` attributes that are concatenated into a dotted
                path. ``model.backbone`` must also support ``**``
                unpacking, because every one of its entries is forwarded
                to the backbone factory as a keyword argument.
        """
        super().__init__()

        backbone_cfg = cfg.model.backbone
        encoder_factory = get_func(
            'mono.model.' + backbone_cfg.prefix + backbone_cfg.type
        )
        self.encoder = encoder_factory(**backbone_cfg)

        head_cfg = cfg.model.decode_head
        decoder_factory = get_func(
            'mono.model.' + head_cfg.prefix + head_cfg.type
        )
        # Unlike the encoder, the decoder factory receives the full config.
        self.decoder = decoder_factory(cfg)

        # No explicit ``self.training = True`` here: ``nn.Module.__init__``
        # already leaves a freshly constructed module in training mode.

    def forward(self, input, **kwargs):
        """Run the encoder, then hand its output to the decoder.

        Args:
            input: Passed unchanged to ``self.encoder``.
            **kwargs: Forwarded unchanged to ``self.decoder``.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        # Original author's note on the encoder output:
        # [f_32, f_16, f_8, f_4]  (presumably multi-scale features -- confirm)
        features = self.encoder(input)
        # Original author's note on the decoder output:
        # [x_32, x_16, x_8, x_4, x, ...]
        return self.decoder(features, **kwargs)