File size: 1,177 Bytes
0b32e3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pdb
from torch import nn
from torch.nn import functional as F
from .loss_functions import Contrastive_Loss, Cosine_Sim_Loss

class _DMMI_Framework(nn.Module):
    def __init__(self, backbone, classifier):
        super(_DMMI_Framework, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.cossim = Cosine_Sim_Loss()
        self.contrastive = Contrastive_Loss()

    def forward(self, x, l_feats, l_feats1, l_mask, target_flag=None, training_flag=True):

        input_shape = x.shape[-2:]
        l_1, features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features
        de_feat, l_2, x = self.classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1)
        seg_mag = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)

        if training_flag and target_flag!=None:
            loss_contrastive = self.contrastive(de_feat, l_1, target_flag)
            loss_cossim = self.cossim(l_1, l_2, l_mask, target_flag)
        else:
            loss_contrastive = 0
            loss_cossim = 0

        return loss_contrastive, loss_cossim, seg_mag

class DMMI(_DMMI_Framework):
    pass