# (Hugging Face Spaces page header residue: "Spaces: Sleeping")
# -*- coding: utf-8 -*-
import itertools

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from torch import nn

# import torchsnooper ## for debug
class DBLoss(nn.Module):
    def __init__(self, alpha=1., beta=10., ohem_ratio=3):
        """
        Implement DB Loss.
        :param alpha: weight of the binary-map (differentiable binarization) loss term
        :param beta: weight of the threshold-map loss term
        :param ohem_ratio: OHEM ratio, i.e. at most this many hard negative pixels
            are kept per positive pixel (used by ohem_single / ohem_batch)
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ohem_ratio = ohem_ratio

    def forward(self, outputs, labels, training_masks, G_d):
        """
        Compute the DB loss.
        :param outputs: N 2 H W; channel 0 is the probability map, channel 1 the threshold map
        :param labels: N 2 H W; channel 0 is the gt probability map, channel 1 the gt threshold map
        :param training_masks: per-pixel mask applied to both dice terms
            (assumed N H W like the sliced maps -- TODO confirm against the data loader)
        :param G_d: per-pixel mask of the region where the threshold map is supervised
            (presumably the dilated text polygon of the DB paper -- verify against the data loader)
        :return: (loss_all, loss_prob, loss_bin, loss_thres), each a 0-dim tensor, where
            loss_all = loss_prob + alpha * loss_bin + beta * loss_thres
        """
        prob_map = outputs[:, 0, :, :]
        thres_map = outputs[:, 1, :, :]
        gt_prob = labels[:, 0, :, :]
        gt_thres = labels[:, 1, :, :]
        # Cast the masks to float and put them on the prediction device so the
        # element-wise products below cannot fail on a dtype/device mismatch.
        G_d = G_d.to(device=thres_map.device, dtype=torch.float32)
        training_masks = training_masks.to(device=prob_map.device, dtype=torch.float32)

        # TODO: OHEM selection (self.ohem_batch) is implemented below but not wired in yet.

        # Probability-map loss.
        loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)
        # Binary-map loss on the differentiable binarization of the two maps.
        bin_map = self.DB(prob_map, thres_map)
        loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)
        # Threshold-map loss: L1 restricted to the G_d region. The per-pixel error
        # must be masked *before* it is reduced; reducing first and multiplying by
        # G_d afterwards only rescales the global mean by mean(G_d).
        loss_thres = self._masked_l1(thres_map, gt_thres, G_d)

        loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
        return loss_all, loss_prob, loss_bin, loss_thres

    @staticmethod
    def _masked_l1(pred, gt, mask, eps=1e-6):
        """Mean absolute error between pred and gt, weighted by mask (0 if mask is empty)."""
        if mask.shape != pred.shape and mask.numel() == pred.numel():
            # e.g. an N 1 H W mask against an N H W map; avoid a wrong broadcast.
            mask = mask.reshape(pred.shape)
        return (torch.abs(pred - gt) * mask).sum() / (mask.sum() + eps)

    def DB(self, prob_map, thres_map, k=50):
        '''
        Differentiable binarization: sigmoid(k * (prob_map - thres_map)).

        Mathematically equal to 1 / (exp(-k * (prob_map - thres_map)) + 1), but
        torch.sigmoid stays finite, and has finite gradients, when the exponent
        is large (the explicit exp form overflows to inf and yields NaN gradients).
        '''
        return torch.sigmoid(k * (prob_map - thres_map))

    def dice_loss(self, pred_cls, gt_cls, training_mask):
        '''
        Dice loss, 1 - 2|P*G| / (|P| + |G|), computed inside training_mask.

        All elements are summed together, so the result is a single batch-level
        dice loss; the three inputs only need element-wise compatible shapes.
        :param pred_cls: predicted map
        :param gt_cls: ground-truth map
        :param training_mask: per-pixel weight/mask
        :return: 0-dim tensor
        '''
        eps = 1e-5
        intersection = torch.sum(gt_cls * pred_cls * training_mask)
        union = torch.sum(gt_cls * training_mask) + torch.sum(pred_cls * training_mask) + eps
        loss = 1. - (2 * intersection / union)
        return loss

    def bce_loss(self, input, target, mask):
        """
        Binary cross-entropy over the pixels selected by mask.

        target is binarized at 0.5 into a new tensor, so the caller's tensor
        (possibly a view into the label batch) is left unchanged.
        :param input: predicted probabilities in [0, 1]
        :param target: soft or hard targets, thresholded at 0.5
        :param mask: pixels with a non-zero value take part in the loss
        :return: 0-dim tensor; 0 when the mask selects no pixel
        """
        if mask.sum() == 0:
            return torch.tensor(0.0, device=input.device, requires_grad=True)
        keep = mask.bool()
        hard_target = (target > 0.5).to(input.dtype)
        return nn.BCELoss(reduction='mean')(input[keep], hard_target[keep])

    def ohem_single(self, score, gt_text):
        """
        Online hard example mining for a single image.

        Selects every positive pixel (gt_text > 0.5) plus the highest-scoring
        negative pixels, at most ohem_ratio negatives per positive.
        :param score: 2-D numpy array of predicted scores
        :param gt_text: numpy array of ground truth, same shape as score
        :return: float32 numpy array with a leading axis of size 1, 1 where selected
        """
        positive = gt_text > 0.5
        negative = gt_text <= 0.5
        pos_num = int(np.sum(positive))
        if pos_num == 0:
            # No positive pixels: nothing is selected.
            return self._as_mask_batch(np.zeros_like(score))
        neg_num = int(min(pos_num * self.ohem_ratio, int(np.sum(negative))))
        if neg_num == 0:
            # No negatives to mine (or ohem_ratio == 0): keep the positives only.
            return self._as_mask_batch(positive)
        neg_score = score[negative]
        # Score of the neg_num-th hardest negative; a partial sort is enough.
        threshold = -np.partition(-neg_score, neg_num - 1)[neg_num - 1]
        selected_mask = (score >= threshold) | positive
        return self._as_mask_batch(selected_mask)

    @staticmethod
    def _as_mask_batch(mask):
        """Return mask as float32 with a leading batch axis of size 1."""
        return mask.reshape(1, *mask.shape).astype('float32')

    def ohem_batch(self, scores, gt_texts):
        """
        Apply ohem_single to every image of a batch.
        :param scores: N H W tensor of predicted scores
        :param gt_texts: N H W tensor of ground truth
        :return: N H W float32 CPU tensor, 1 where a pixel is selected
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        if scores.shape[0] == 0:
            # np.concatenate cannot handle an empty list.
            return torch.zeros(scores.shape, dtype=torch.float32)
        selected_masks = [self.ohem_single(score, gt_text)
                          for score, gt_text in zip(scores, gt_texts)]
        return torch.from_numpy(np.concatenate(selected_masks, 0)).float()