# -*- coding: utf-8 -*- import itertools import torch from torch import nn import numpy as np import cv2 import torchvision.transforms as transforms # import torchsnooper ## for debug class DBLoss(nn.Module): def __init__(self, alpha=1., beta=10., ohem_ratio=3): """ Implement DB Loss. :param alpha: loss binary_map 前面的系数 :param beta: loss threshold 前面的系数 :param ohem_ratio: OHEM的比例 """ super().__init__() self.alpha = alpha self.beta = beta self.ohem_ratio = ohem_ratio def forward(self, outputs, labels, training_masks, G_d): """ Implement DB Loss. :param outputs: N 2 H W :param labels: N 2 H W :param training_masks: """ prob_map = outputs[:, 0, :, :] thres_map = outputs[:, 1, :, :] gt_prob = labels[:, 0, :, :] gt_thres = labels[:, 1, :, :] G_d = G_d.to(dtype = torch.float32) training_masks = training_masks.to(dtype = torch.float32) # OHEM mask (todo) # selected_masks = self.ohem_batch(prob_map, gt_prob) # selected_masks = selected_masks.to(outputs.device) # 计算 prob loss loss_prob = self.dice_loss(prob_map, gt_prob, training_masks) # loss_prob = self.bce_loss(prob_map, gt_prob, selected_masks) # 计算 binary map loss bin_map = self.DB(prob_map, thres_map) loss_bin = self.dice_loss(bin_map, gt_prob, training_masks) # loss_prob = self.bce_loss(bin_map, gt_prob, selected_masks) # 计算 threshold map loss loss_fn = torch.nn.L1Loss(reduction='mean') L1_loss = loss_fn(thres_map, gt_thres) loss_thres = L1_loss * G_d loss_prob = loss_prob.mean() loss_bin = loss_bin.mean() loss_thres = loss_thres.mean() loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres return loss_all, loss_prob, loss_bin, loss_thres def DB(self, prob_map, thres_map, k=50): ''' Differentiable binarization another form: torch.sigmoid(k * (prob_map - thres_map)) ''' return 1. / (torch.exp((-k * (prob_map - thres_map))) + 1) def dice_loss(self, pred_cls, gt_cls, training_mask): ''' dice loss 此处默认真实值和预测值的格式均为 NCHW :param gt_cls: :param pred_cls: :param training_mask: :return: ''' eps = 1e-5 intersection = torch.sum(gt_cls * pred_cls * training_mask) union = torch.sum(gt_cls * training_mask) + torch.sum(pred_cls * training_mask) + eps loss = 1. - (2 * intersection / union) return loss def bce_loss(self, input, target, mask): if mask.sum() == 0: return torch.tensor(0.0, device=input.device, requires_grad=True) target[target <= 0.5] = 0 target[target > 0.5] = 1 input = input[mask.bool()] target = target[mask.bool()] loss = nn.BCELoss(reduction='mean')(input, target) return loss def ohem_single(self, score, gt_text): pos_num = (int)(np.sum(gt_text > 0.5)) if pos_num == 0: selected_mask = np.zeros_like(score) selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') return selected_mask neg_num = (int)(np.sum(gt_text <= 0.5)) neg_num = (int)(min(pos_num * self.ohem_ratio, neg_num)) if neg_num == 0: selected_mask = np.zeros_like(score) selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') return selected_mask neg_score = score[gt_text <= 0.5] neg_score_sorted = np.sort(-neg_score) threshold = -neg_score_sorted[neg_num - 1] selected_mask = (score >= threshold) | (gt_text > 0.5) selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') return selected_mask def ohem_batch(self, scores, gt_texts): scores = scores.data.cpu().numpy() gt_texts = gt_texts.data.cpu().numpy() selected_masks = [] for i in range(scores.shape[0]): selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :])) selected_masks = np.concatenate(selected_masks, 0) selected_masks = torch.from_numpy(selected_masks).float() return selected_masks