File size: 4,663 Bytes
0742dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-

import itertools
import torch
from torch import nn
import numpy as np
import cv2
import torchvision.transforms as transforms

# import torchsnooper  ## for debug

class DBLoss(nn.Module):
    def __init__(self, alpha=1., beta=10., ohem_ratio=3):
        """
        Implement DB (Differentiable Binarization) loss.

        :param alpha: weight of the binary-map loss term
        :param beta: weight of the threshold-map loss term
        :param ohem_ratio: max negatives kept per positive pixel by OHEM
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ohem_ratio = ohem_ratio

    def forward(self, outputs, labels, training_masks, G_d):
        """
        Compute the total DB loss and its three components.

        :param outputs: N x 2 x H x W; channel 0 is the predicted probability
            map, channel 1 the predicted threshold map
        :param labels: N x 2 x H x W; channel 0 is the ground-truth probability
            map, channel 1 the ground-truth threshold map
        :param training_masks: mask multiplied into both dice losses; assumed
            broadcastable with N x H x W (TODO confirm against the data loader)
        :param G_d: mask of the pixels on which the threshold loss is computed
            (presumably the dilated text region of the DB paper); must be
            expandable to N x H x W
        :return: (loss_all, loss_prob, loss_bin, loss_thres), all scalars
        """
        prob_map = outputs[:, 0, :, :]
        thres_map = outputs[:, 1, :, :]
        gt_prob = labels[:, 0, :, :]
        gt_thres = labels[:, 1, :, :]

        G_d = G_d.to(dtype=torch.float32)
        training_masks = training_masks.to(dtype=torch.float32)

        # TODO: optional OHEM path — build masks with self.ohem_batch(prob_map,
        # gt_prob), move them to outputs.device and use self.bce_loss instead of
        # dice loss for the probability / binary maps.

        # Probability-map loss.
        loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)

        # Approximate binary-map loss (supervised by the same target as prob map).
        bin_map = self.DB(prob_map, thres_map)
        loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)

        # Threshold-map loss: per-pixel L1 averaged over the G_d region only.
        # Reducing to a scalar *before* masking (as a mean-reduced L1Loss does)
        # would ignore where the errors are and just rescale by mean(G_d).
        eps = 1e-6
        l1 = torch.abs(thres_map - gt_thres)
        region = G_d.expand_as(l1)  # fails loudly on a mis-shaped mask
        loss_thres = (l1 * region).sum() / (region.sum() + eps)

        loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
        return loss_all, loss_prob, loss_bin, loss_thres

    def DB(self, prob_map, thres_map, k=50):
        '''
        Differentiable binarization: 1 / (1 + exp(-k * (P - T))).

        Implemented with torch.sigmoid, which is mathematically identical but
        numerically stable: the explicit exp form overflows for large k*|P-T|
        (notably in fp16) and then yields NaN gradients.
        '''
        return torch.sigmoid(k * (prob_map - thres_map))

    def dice_loss(self, pred_cls, gt_cls, training_mask):
        '''
        Dice loss computed over all elements of the batch at once.

        Inputs may have any shape as long as the three tensors broadcast
        together (forward passes N x H x W maps).

        :param pred_cls: predicted map
        :param gt_cls: ground-truth map
        :param training_mask: mask multiplied into every term
        :return: scalar tensor, 1 - 2|X∩Y| / (|X| + |Y| + eps)
        '''
        eps = 1e-5

        intersection = torch.sum(gt_cls * pred_cls * training_mask)
        union = torch.sum(gt_cls * training_mask) + torch.sum(pred_cls * training_mask) + eps
        loss = 1. - (2 * intersection / union)

        return loss

    def bce_loss(self, input, target, mask):
        '''
        Binary cross-entropy over the pixels selected by ``mask``.

        ``target`` is binarised at 0.5 into a new tensor; the caller's tensor
        is not modified (it may be a view into the label batch).

        :return: scalar mean BCE, or a zero tensor if the mask selects nothing
        '''
        if mask.sum() == 0:
            return torch.tensor(0.0, device=input.device, requires_grad=True)
        selected = mask.bool()
        binary_target = (target > 0.5).to(input.dtype)
        return nn.BCELoss(reduction='mean')(input[selected], binary_target[selected])

    def ohem_single(self, score, gt_text):
        '''
        Online hard example mining for one H x W image (numpy arrays).

        Keeps all positive pixels (gt > 0.5) plus the highest-scoring
        negatives, at most ``ohem_ratio`` negatives per positive.

        :return: float32 array of shape (1, H, W) with 1 for selected pixels
        '''
        positive = gt_text > 0.5
        pos_num = int(np.sum(positive))

        if pos_num == 0:
            # No positives: nothing to balance against, select nothing.
            return np.zeros_like(score, dtype=np.float32)[np.newaxis]

        negative = gt_text <= 0.5
        neg_num = min(pos_num * self.ohem_ratio, int(np.sum(negative)))
        neg_num = int(neg_num)

        if neg_num == 0:
            # Only positives present: keep them (consistent with the general case).
            return positive.astype(np.float32)[np.newaxis]

        # Score of the neg_num-th hardest (highest-scoring) negative.
        neg_score = score[negative]
        threshold = -np.sort(-neg_score)[neg_num - 1]
        selected_mask = (score >= threshold) | positive
        return selected_mask.astype(np.float32)[np.newaxis]

    def ohem_batch(self, scores, gt_texts):
        '''
        Apply ``ohem_single`` to every image of an N x H x W batch.

        :return: CPU float tensor of shape N x H x W
        '''
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        selected_masks = np.concatenate(
            [self.ohem_single(score, gt) for score, gt in zip(scores, gt_texts)], 0)
        return torch.from_numpy(selected_masks).float()