File size: 1,599 Bytes
0b32e3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import pdb
from torch import nn
from torch.nn import functional as F
from .loss_functions import Contrastive_Loss, Cosine_Sim_Loss
class _DMMI_MRACL_Framework(nn.Module):
    """Wrap a backbone and a classifier head, optionally computing auxiliary losses.

    The backbone consumes the input together with language features and a
    language mask. The classifier fuses the backbone's language output with a
    second set of language features and the four feature maps produced by the
    backbone. The classifier's third output is bilinearly upsampled to the
    spatial size of the input.

    Args:
        backbone: callable ``backbone(x, l_feats, l_mask)`` returning
            ``(l_1, (x_c1, x_c2, x_c3, x_c4))``.
        classifier: callable ``classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1)``
            returning ``(de_feat, l_2, out)``. ``out`` is passed to
            ``F.interpolate(mode='bilinear')``, which requires a 4-D tensor.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.cossim = Cosine_Sim_Loss()
        self.contrastive = Contrastive_Loss()

    def forward(self, x, l_feats, l_feats1, l_mask, cl_masks=None,
                target_flag=None, training_flag=True):
        """Run backbone and classifier and, when requested, the auxiliary losses.

        Args:
            x: input tensor. Its last two dimensions are used as the output
                spatial size (presumably an image batch (B, C, H, W) -- TODO
                confirm).
            l_feats: language features fed to the backbone.
            l_feats1: language features fed to the classifier.
            l_mask: language mask. It is fed to the backbone and, after
                indexing by ``cl_masks``, to the cosine-similarity loss.
            cl_masks: index applied to ``de_feat``, ``l_1``, ``l_2`` and
                ``l_mask`` before the losses. It is presumably a boolean mask
                over the batch; verify against the caller. It is only used
                when the losses are computed.
            target_flag: passed to both losses. When ``None``, the losses
                are skipped.
            training_flag: when falsy, the losses are skipped.

        Returns:
            If ``training_flag`` is truthy and ``target_flag`` is not ``None``:
            ``(loss_contrastive, loss_cossim, seg_mag, x_c4)``.
            Otherwise: ``(0, 0, seg_mag)``. Note that this tuple is one
            element shorter.
        """
        input_shape = x.shape[-2:]
        l_1, features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features
        de_feat, l_2, cls_out = self.classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1)
        # Upsample the classifier output back to the input's spatial size.
        seg_mag = F.interpolate(cls_out, size=input_shape, mode='bilinear', align_corners=True)

        # Use `is not None` rather than `!= None`. An array-like target_flag
        # (e.g. a NumPy array) compares element-wise under `!=`, and the
        # resulting array has an ambiguous truth value inside `if`.
        if training_flag and target_flag is not None:
            loss_contrastive = self.contrastive(de_feat[cl_masks], l_1[cl_masks], target_flag)
            loss_cossim = self.cossim(l_1[cl_masks], l_2[cl_masks], l_mask[cl_masks], target_flag)
            return loss_contrastive, loss_cossim, seg_mag, x_c4

        # No-loss path: zero placeholders for both losses and no x_c4.
        return 0, 0, seg_mag
class DMMI_MRACL(_DMMI_MRACL_Framework):
    """Public model class; inherits all behavior from ``_DMMI_MRACL_Framework`` unchanged."""