import pdb

from torch import nn
from torch.nn import functional as F

from .loss_functions import Contrastive_Loss, Cosine_Sim_Loss


class _DMMI_MRACL_Framework(nn.Module):
    """Combine a backbone and a classifier head, with optional auxiliary losses.

    The backbone consumes the image input plus language features/mask and
    yields a language feature ``l_1`` and four multi-level visual features.
    The classifier consumes those and yields ``(de_feat, l_2, x)``; ``x`` is
    resized to the input's spatial size. In training mode (with a
    ``target_flag``) a contrastive loss and a cosine-similarity loss are also
    computed via ``Contrastive_Loss`` / ``Cosine_Sim_Loss``.
    """

    def __init__(self, backbone, classifier):
        """Store the sub-modules and instantiate the two auxiliary losses.

        Args:
            backbone: Callable module invoked as
                ``backbone(x, l_feats, l_mask) -> (l_1, (x_c1, x_c2, x_c3, x_c4))``.
            classifier: Callable module invoked as
                ``classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1) -> (de_feat, l_2, x)``.
        """
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.cossim = Cosine_Sim_Loss()
        self.contrastive = Contrastive_Loss()

    def forward(self, x, l_feats, l_feats1, l_mask, cl_masks=None,
                target_flag=None, training_flag=True):
        """Run backbone + classifier and, when training, the auxiliary losses.

        Args:
            x: Input tensor; its last two dimensions are used as the output
                spatial size (assumed (N, C, H, W) image batch — TODO confirm).
            l_feats: Language features passed to the backbone.
            l_feats1: Language features passed to the classifier.
            l_mask: Language mask passed to the backbone and (filtered by
                ``cl_masks``) to the cosine-similarity loss.
            cl_masks: Index applied to the leading axis of ``de_feat``,
                ``l_1``, ``l_2`` and ``l_mask`` before the losses (presumably
                a boolean per-sample mask — verify against callers). ``None``
                means "use every sample".
            target_flag: Extra argument forwarded to both losses; losses are
                only computed when it is not ``None``.
            training_flag: When falsy, losses are skipped.

        Returns:
            Training branch (``training_flag`` truthy and ``target_flag`` not
            ``None``): ``(loss_contrastive, loss_cossim, seg_mag, x_c4)``.
            Otherwise: ``(0, 0, seg_mag)``. The tuple lengths deliberately
            differ, so callers must unpack according to the mode.
        """
        input_shape = x.shape[-2:]
        l_1, features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features
        de_feat, l_2, x = self.classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1)
        # 'bilinear' interpolation requires a 4-D (N, C, H, W) tensor.
        seg_mag = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)

        if training_flag and target_flag is not None:
            if cl_masks is None:
                # Indexing a tensor with None inserts a new leading axis
                # instead of selecting rows; treat "no mask" as "keep all".
                cl_masks = slice(None)
            loss_contrastive = self.contrastive(de_feat[cl_masks], l_1[cl_masks], target_flag)
            loss_cossim = self.cossim(l_1[cl_masks], l_2[cl_masks], l_mask[cl_masks], target_flag)
            return loss_contrastive, loss_cossim, seg_mag, x_c4

        # Inference / no-target path: placeholder zero losses and no x_c4.
        loss_contrastive = 0
        loss_cossim = 0
        return loss_contrastive, loss_cossim, seg_mag


class DMMI_MRACL(_DMMI_MRACL_Framework):
    """Public alias of the framework; adds no behavior of its own."""