# Metric3D / training / mono / model / monodepth_model.py
# zach
# initial commit based on github repo
# 3ef1661
import torch
import torch.nn as nn
from mono.utils.comm import get_func
from .__base_model__ import BaseDepthModel
class DepthModel(BaseDepthModel):
    """Depth model that defers all construction to ``BaseDepthModel``.

    The subclass adds no layers or methods of its own; it only forwards the
    configuration and criterions to the base class and makes sure the
    ``training`` flag starts out as ``True``.
    """

    def __init__(self, cfg, criterions, **kwargs):
        """Initialise the model.

        Args:
            cfg: Configuration object, forwarded unchanged to
                ``BaseDepthModel.__init__``.
            criterions: Criterions, forwarded unchanged to
                ``BaseDepthModel.__init__`` (presumably the training
                losses -- confirm against the base class).
            **kwargs: Accepted so callers may pass extra keyword arguments;
                they are not used and not forwarded to the base class.
        """
        super().__init__(cfg, criterions)
        # Mark the freshly built instance as being in training mode. This is
        # a plain attribute assignment on this object only.
        self.training = True
def get_monodepth_model(cfg: dict, criterions: dict, **kwargs) -> nn.Module:
    """Instantiate a ``DepthModel`` from a configuration and criterions.

    Args:
        cfg: Model configuration, passed straight to ``DepthModel``.
        criterions: Criterions, passed straight to ``DepthModel``.
        **kwargs: Extra keyword arguments forwarded to ``DepthModel``.

    Returns:
        The newly constructed model.

    Raises:
        AssertionError: If the constructed object is not an ``nn.Module``.
    """
    depth_model = DepthModel(cfg, criterions, **kwargs)
    # NOTE: no weight initialisation is performed here; the model is returned
    # exactly as its constructor left it.
    # Sanity check: the constructed object must be a torch module.
    assert isinstance(depth_model, nn.Module)
    return depth_model
def get_configured_monodepth_model(cfg: dict, criterions: dict) -> nn.Module:
    """Build the depth model for the given configuration.

    Thin wrapper around ``get_monodepth_model`` that forwards both arguments
    and supplies no extra keyword arguments.

    Args:
        cfg: Configuration for the network.
        criterions: Criterions handed on to the model.

    Returns:
        The depth model produced by ``get_monodepth_model``.
    """
    return get_monodepth_model(cfg, criterions)