Spaces:
Sleeping
Sleeping
File size: 4,663 Bytes
0742dfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# -*- coding: utf-8 -*-
import itertools
import torch
from torch import nn
import numpy as np
import cv2
import torchvision.transforms as transforms
# import torchsnooper ## for debug
class DBLoss(nn.Module):
    """Training loss for Differentiable Binarization (DB) text detection.

    The total loss combines three terms::

        loss_all = loss_prob + alpha * loss_bin + beta * loss_thres

    where ``loss_prob`` and ``loss_bin`` are dice losses on the probability
    map and on the approximate binary map, and ``loss_thres`` is an L1 loss
    on the threshold map restricted to the region selected by ``G_d``.
    """

    def __init__(self, alpha=1., beta=10., ohem_ratio=3):
        """
        Configure the loss weights.

        :param alpha: weight applied to the binary-map loss
        :param beta: weight applied to the threshold-map loss
        :param ohem_ratio: maximum number of hard negatives kept per positive
            pixel by ``ohem_single``
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ohem_ratio = ohem_ratio

    def forward(self, outputs, labels, training_masks, G_d):
        """
        Compute the DB loss.

        :param outputs: N 2 H W, channel 0 is the probability map and
            channel 1 is the threshold map
        :param labels: N 2 H W, channel 0 is the ground-truth probability map
            and channel 1 is the ground-truth threshold map
        :param training_masks: multiplied element-wise with the N H W maps in
            the dice loss; pixels with weight 0 are ignored
        :param G_d: mask selecting the pixels that contribute to the
            threshold loss (assumed to broadcast against the N H W threshold
            map -- TODO confirm against the data loader)
        :return: tuple ``(loss_all, loss_prob, loss_bin, loss_thres)`` of
            scalar tensors
        """
        prob_map = outputs[:, 0, :, :]
        thres_map = outputs[:, 1, :, :]
        gt_prob = labels[:, 0, :, :]
        gt_thres = labels[:, 1, :, :]

        G_d = G_d.to(dtype=torch.float32)
        training_masks = training_masks.to(dtype=torch.float32)

        # TODO: OHEM is not wired in yet. ``ohem_batch`` can produce the
        # selection mask and ``bce_loss`` can consume it as an alternative
        # to the dice losses below.

        # Probability-map loss.
        loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)

        # Binary-map loss on the differentiable binarization output.
        bin_map = self.DB(prob_map, thres_map)
        loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)

        # Threshold-map loss: L1 error averaged over the G_d region only.
        # Reducing the L1 error to a scalar *before* applying G_d would make
        # pixels outside the region contribute and turn the mask into a mere
        # rescaling factor.
        loss_thres = self._masked_l1_loss(thres_map, gt_thres, G_d)

        # dice_loss already returns scalars; mean() keeps the reduction
        # explicit and is a no-op on 0-dim tensors.
        loss_prob = loss_prob.mean()
        loss_bin = loss_bin.mean()

        loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
        return loss_all, loss_prob, loss_bin, loss_thres

    @staticmethod
    def _masked_l1_loss(pred, target, mask, eps=1e-5):
        """Mean absolute error over the pixels weighted by ``mask``.

        Returns 0 when the mask is empty (the denominator is clamped to
        ``eps`` so no division by zero occurs).
        """
        weighted_error = torch.abs(pred - target) * mask
        return weighted_error.sum() / torch.clamp(mask.sum(), min=eps)

    def DB(self, prob_map, thres_map, k=50):
        """
        Differentiable binarization: ``sigmoid(k * (prob_map - thres_map))``.

        ``torch.sigmoid`` is used instead of the explicit
        ``1 / (1 + exp(-k * x))`` form because the explicit form overflows
        in ``exp`` for large negative ``x`` and then yields NaN gradients.

        :param prob_map: predicted probability map
        :param thres_map: predicted threshold map
        :param k: amplification factor controlling the steepness of the step
        :return: approximate binary map, same shape as the inputs
        """
        return torch.sigmoid(k * (prob_map - thres_map))

    def dice_loss(self, pred_cls, gt_cls, training_mask):
        """
        Dice loss computed over the whole batch at once.

        The original note states that ground truth and prediction are
        expected in NCHW layout; the computation itself only requires that
        the three tensors broadcast against each other.

        :param pred_cls: predicted map
        :param gt_cls: ground-truth map
        :param training_mask: per-pixel weights; 0 excludes a pixel
        :return: scalar tensor ``1 - 2 * intersection / union``
        """
        eps = 1e-5  # keeps the ratio defined when both maps are empty
        intersection = torch.sum(gt_cls * pred_cls * training_mask)
        union = (torch.sum(gt_cls * training_mask)
                 + torch.sum(pred_cls * training_mask) + eps)
        return 1. - (2 * intersection / union)

    def bce_loss(self, input, target, mask):
        """
        Binary cross-entropy over the pixels selected by ``mask``.

        The target is binarized at 0.5 into a new tensor, so the caller's
        label tensor is left untouched.

        :param input: predicted probabilities in [0, 1]
        :param target: ground-truth map, binarized at 0.5
        :param mask: non-zero entries mark the pixels to include
        :return: scalar tensor; 0 when no pixel is selected
        """
        selected = mask.bool()
        if not selected.any():
            # Nothing to average over; return a zero that still supports
            # backward() so the training loop does not need a special case.
            return torch.tensor(0.0, device=input.device, requires_grad=True)

        binary_target = (target > 0.5).to(input.dtype)
        criterion = nn.BCELoss(reduction='mean')
        return criterion(input[selected], binary_target[selected])

    def ohem_single(self, score, gt_text):
        """
        Online hard example mining for a single H x W image.

        Keeps every positive pixel (``gt_text > 0.5``) and the highest
        scoring negative pixels, at most ``ohem_ratio`` negatives per
        positive.

        :param score: 2-D numpy array of predicted scores
        :param gt_text: 2-D numpy array of ground-truth values
        :return: float32 numpy array of shape (1, H, W) with 1 for selected
            pixels; all zeros when there are no positives or no negatives
        """
        out_shape = (1, score.shape[0], score.shape[1])

        pos_num = int(np.sum(gt_text > 0.5))
        if pos_num == 0:
            return np.zeros(out_shape, dtype=np.float32)

        neg_num = int(np.sum(gt_text <= 0.5))
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))
        if neg_num == 0:
            return np.zeros(out_shape, dtype=np.float32)

        # Score of the neg_num-th hardest negative: anything scoring at
        # least this much is kept.
        neg_score = score[gt_text <= 0.5]
        neg_score_sorted = np.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = (score >= threshold) | (gt_text > 0.5)
        return selected_mask.reshape(out_shape).astype(np.float32)

    def ohem_batch(self, scores, gt_texts):
        """
        Apply ``ohem_single`` to every image of a batch.

        :param scores: tensor of predicted scores, N H W
        :param gt_texts: tensor of ground-truth values, N H W
        :return: float CPU tensor of shape (N, H, W) with the selection masks
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()

        selected_masks = [
            self.ohem_single(score, gt_text)
            for score, gt_text in zip(scores, gt_texts)
        ]
        selected_masks = np.concatenate(selected_masks, 0)
        return torch.from_numpy(selected_masks).float()
|