jmliu committed · Commit 0742dfe · 1 Parent(s): dbd0af1
app.py ADDED
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from torchvision import transforms
+import os
+import cv2
+import time
+import numpy as np
+
+import pyclipper
+import gradio as gr
+
+from models import get_model
+from utils.util import show_img, draw_bbox
+
+class Pytorch_model:
+    def __init__(self, model_path, gpu_id=None):
+        '''
+        Initialize the PyTorch model.
+        :param model_path: path to the checkpoint (either bare weights, or weights saved together with the graph)
+        :param gpu_id: id of the GPU to run on, or None for CPU
+        '''
+        self.gpu_id = gpu_id
+
+        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
+            self.device = torch.device("cuda:%s" % self.gpu_id)
+            checkpoint = torch.load(model_path)
+        else:
+            self.device = torch.device("cpu")
+            checkpoint = torch.load(model_path, map_location='cpu')
+        print('device:', self.device)
+
+        config = checkpoint['config']
+        config['arch']['args']['pretrained'] = False
+        self.net = get_model(config)
+
+        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
+        self.net.load_state_dict(checkpoint['state_dict'])  # load weights
+        self.net.to(self.device)
+        self.net.eval()
+
+    def predict(self, img, short_size: int = 736, min_area: int = 100):
+        '''
+        Run prediction on an image; accepts either a file path (read via OpenCV, which is slower) or an image already loaded as a NumPy array.
+        :param img: image path or np.ndarray
+        :param short_size: the shorter image side is scaled to this size
+        :param min_area: bounding boxes smaller than this area are ignored
+        :return: (dilated polygons, bounding boxes, elapsed time)
+        '''
+        if isinstance(img, str):  # a path on disk: OpenCV loads BGR, so convert
+            img = cv2.imread(img)
+            if self.img_channel == 3:
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        h, w = img.shape[:2]
+        scale = short_size / min(h, w)
+        img = cv2.resize(img, None, fx=scale, fy=scale)
+
+        tensor = transforms.ToTensor()(img)
+        tensor = tensor.unsqueeze_(0)
+
+        tensor = tensor.to(self.device)
+        with torch.no_grad():
+            if self.device.type == 'cuda':  # synchronizing is only valid (and needed) on GPU
+                torch.cuda.synchronize(self.device)
+            start = time.time()
+            preds = self.net(tensor)[0]
+            if self.device.type == 'cuda':
+                torch.cuda.synchronize(self.device)
+            scale = (preds.shape[2] / w, preds.shape[1] / h)
+            t = time.time() - start
+
+            '''inference'''
+            start = time.time()
+            prob_map, thres_map = preds[0], preds[1]
+
+            # Step 1: threshold the probability map to get the binary map
+            thr = 0.2
+            out = (prob_map > thr).float() * 255
+            out = out.data.cpu().numpy().astype(np.uint8)
+            # cv2.imwrite('c_bin_map.png', out)
+
+            # Step 2: find connected components via findContours
+            contours, hierarchy = cv2.findContours(out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            contours = [(i / scale).astype(np.int32) for i in contours if len(i) >= 4]  # OpenCV expects 32-bit int points
+
+            # Step 3: dilate the shrunk regions (optional)
+            ratio_prime = 1.5
+            dilated_polys = []
+            for poly in contours:
+                poly = poly[:, 0, :]
+                D_prime = cv2.contourArea(poly) * ratio_prime / cv2.arcLength(poly, True)  # formula (10) in the paper
+                pco = pyclipper.PyclipperOffset()
+                pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+                dilated_poly = np.array(pco.Execute(D_prime))
+                if dilated_poly.size == 0 or dilated_poly.dtype != int or len(dilated_poly) != 1:  # ragged offsets have dtype=object
+                    continue
+                dilated_polys.append(dilated_poly.astype(np.int32))
+
+            boxes_list = []
+            for cnt in dilated_polys:
+                # print('=============')
+                # print(cnt)
+                # print(len(cnt))
+                if cv2.contourArea(cnt) < min_area:
+                    continue
+                rect = cv2.minAreaRect(cnt)
+                box = (cv2.boxPoints(rect)).astype(np.int32)
+                boxes_list.append(box)
+
+            t = time.time() - start + t
+
+            boxes_list = np.array(boxes_list)
+            return dilated_polys, boxes_list, t
+
+
+def run(img):
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+    model_path = 'pre_trained/MobileNetv2_best_loss.pth'
+
+    # init model (reloaded on every call; hoist to module scope to avoid the overhead)
+    model = Pytorch_model(model_path, gpu_id=None)  # set a GPU id, or None if you only have a CPU
+    contours, boxes_list, t = model.predict(img)
+    print('Time: %.4f' % t)
+
+    imgc = img.copy()  # Gradio provides RGB, so no channel swap is needed before drawing and returning
+    cv2.drawContours(imgc, contours, -1, (22, 222, 22), 2, cv2.LINE_AA)
+    return imgc
+    # cv2.imwrite('contour.png', imgc)
+    # img = draw_bbox(img, boxes_list)
+    # cv2.imwrite('predict.jpg', img)
+
+if __name__ == '__main__':
+    iface = gr.Interface(fn=run,
+                         title="CJK Font Detection Using DBNet",
+                         description="Given an image containing CJK fonts, we will mark the text out!",
+                         inputs="image",
+                         outputs="image")
+    iface.launch()
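For reference, the post-processing in `predict` is DBNet's "unshrink" step: each contour found in the binary map is offset outwards by D' = A * r' / L (contour area times the unshrink ratio, divided by the perimeter). A minimal standalone sketch of just that step, assuming a hypothetical toy rectangle as the shrunk region (the pyclipper calls match the ones used above):

import cv2
import numpy as np
import pyclipper

poly = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.int32)  # toy shrunk region
ratio_prime = 1.5
d_prime = cv2.contourArea(poly) * ratio_prime / cv2.arcLength(poly, True)  # D' = A * r' / L
pco = pyclipper.PyclipperOffset()
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
dilated = pco.Execute(d_prime)  # list of offset polygons; usually exactly one
print(d_prime, np.array(dilated[0]).shape)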
models/__init__.py ADDED
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+from .model import Model
+from .loss import DBLoss
+
+
+def get_model(config):
+    model_config = config['arch']['args']
+    return Model(model_config)
+
+def get_loss(config):
+    alpha = config['loss']['args']['alpha']
+    beta = config['loss']['args']['beta']
+    ohem_ratio = config['loss']['args']['ohem_ratio']
+    return DBLoss(alpha=alpha, beta=beta, ohem_ratio=ohem_ratio)
models/loss.py ADDED
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+
+import itertools
+import torch
+from torch import nn
+import numpy as np
+import cv2
+import torchvision.transforms as transforms
+
+# import torchsnooper  # for debugging
+
+class DBLoss(nn.Module):
+    def __init__(self, alpha=1., beta=10., ohem_ratio=3):
+        """
+        Implement DB Loss.
+        :param alpha: weight of the binary-map loss
+        :param beta: weight of the threshold-map loss
+        :param ohem_ratio: negative-to-positive sampling ratio for OHEM
+        """
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.ohem_ratio = ohem_ratio
+
+    def forward(self, outputs, labels, training_masks, G_d):
+        """
+        Implement DB Loss.
+        :param outputs: N 2 H W
+        :param labels: N 2 H W
+        :param training_masks:
+        """
+        prob_map = outputs[:, 0, :, :]
+        thres_map = outputs[:, 1, :, :]
+        gt_prob = labels[:, 0, :, :]
+        gt_thres = labels[:, 1, :, :]
+
+        G_d = G_d.to(dtype=torch.float32)
+        training_masks = training_masks.to(dtype=torch.float32)
+
+        # OHEM mask (todo)
+        # selected_masks = self.ohem_batch(prob_map, gt_prob)
+        # selected_masks = selected_masks.to(outputs.device)
+
+        # probability-map loss
+        loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)
+        # loss_prob = self.bce_loss(prob_map, gt_prob, selected_masks)
+
+        # binary-map loss
+        bin_map = self.DB(prob_map, thres_map)
+        loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)
+        # loss_prob = self.bce_loss(bin_map, gt_prob, selected_masks)
+
+        # threshold-map loss
+        loss_fn = torch.nn.L1Loss(reduction='mean')
+        L1_loss = loss_fn(thres_map, gt_thres)
+        loss_thres = L1_loss * G_d
+
+        loss_prob = loss_prob.mean()
+        loss_bin = loss_bin.mean()
+        loss_thres = loss_thres.mean()
+
+        loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
+        return loss_all, loss_prob, loss_bin, loss_thres
+
+    def DB(self, prob_map, thres_map, k=50):
+        '''
+        Differentiable binarization
+        another form: torch.sigmoid(k * (prob_map - thres_map))
+        '''
+        return 1. / (torch.exp((-k * (prob_map - thres_map))) + 1)
+
+    def dice_loss(self, pred_cls, gt_cls, training_mask):
+        '''
+        dice loss
+        Both ground truth and prediction are assumed to be in NCHW format.
+        :param gt_cls:
+        :param pred_cls:
+        :param training_mask:
+        :return:
+        '''
+        eps = 1e-5
+
+        intersection = torch.sum(gt_cls * pred_cls * training_mask)
+        union = torch.sum(gt_cls * training_mask) + torch.sum(pred_cls * training_mask) + eps
+        loss = 1. - (2 * intersection / union)
+
+        return loss
+
+    def bce_loss(self, input, target, mask):
+        if mask.sum() == 0:
+            return torch.tensor(0.0, device=input.device, requires_grad=True)
+        target[target <= 0.5] = 0
+        target[target > 0.5] = 1
+        input = input[mask.bool()]
+        target = target[mask.bool()]
+        loss = nn.BCELoss(reduction='mean')(input, target)
+        return loss
+
+    def ohem_single(self, score, gt_text):
+        pos_num = int(np.sum(gt_text > 0.5))
+
+        if pos_num == 0:
+            selected_mask = np.zeros_like(score)
+            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+            return selected_mask
+
+        neg_num = int(np.sum(gt_text <= 0.5))
+        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))
+
+        if neg_num == 0:
+            selected_mask = np.zeros_like(score)
+            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+            return selected_mask
+
+        neg_score = score[gt_text <= 0.5]
+        neg_score_sorted = np.sort(-neg_score)
+        threshold = -neg_score_sorted[neg_num - 1]
+        selected_mask = (score >= threshold) | (gt_text > 0.5)
+        selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+        return selected_mask
+
+    def ohem_batch(self, scores, gt_texts):
+        scores = scores.data.cpu().numpy()
+        gt_texts = gt_texts.data.cpu().numpy()
+        selected_masks = []
+        for i in range(scores.shape[0]):
+            selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :]))
+
+        selected_masks = np.concatenate(selected_masks, 0)
+        selected_masks = torch.from_numpy(selected_masks).float()
+
+        return selected_masks
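For intuition, the `DB` method above implements the differentiable binarization from the DBNet paper, B = 1 / (1 + e^(-k(P - T))), which is exactly a sigmoid with steepness k. A quick standalone check on toy tensors, assuming the default k = 50:

import torch

k = 50
prob_map, thres_map = torch.rand(1, 8, 8), torch.rand(1, 8, 8)
db = 1. / (torch.exp(-k * (prob_map - thres_map)) + 1)
print(torch.allclose(db, torch.sigmoid(k * (prob_map - thres_map))))  # True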
models/model.py ADDED
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from models.modules import *
+# from modules import *
+
+backbone_dict = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]},
+                 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]},
+                 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]},
+                 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]},
+                 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]},
+                 'resnext50_32x4d': {'models': resnext50_32x4d, 'out': [256, 512, 1024, 2048]},
+                 'resnext101_32x8d': {'models': resnext101_32x8d, 'out': [256, 512, 1024, 2048]},
+                 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]},
+                 'mobilenetv2': {'models': mobilenet_v2_x1_0, 'out': [24, 40, 160, 160]}
+                 }
+
+segmentation_head_dict = {'FPN': FPN, 'FPEM_FFM': FPEM_FFM}
+
+# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
+# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
+# 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}
+
+class Model(nn.Module):
+    def __init__(self, model_config: dict):
+        """
+        PANnet
+        :param model_config: model configuration
+        """
+        super().__init__()
+        backbone = model_config['backbone']
+        pretrained = model_config['pretrained']
+        segmentation_head = model_config['segmentation_head']
+
+        assert backbone in backbone_dict, 'backbone must be in: {}'.format(backbone_dict)
+        assert segmentation_head in segmentation_head_dict, 'segmentation_head must be in: {}'.format(
+            segmentation_head_dict)
+
+        backbone_model, backbone_out = backbone_dict[backbone]['models'], backbone_dict[backbone]['out']
+        self.backbone = backbone_model(pretrained=pretrained)
+        self.segmentation_head = segmentation_head_dict[segmentation_head](backbone_out, **model_config)
+        self.name = '{}_{}'.format(backbone, segmentation_head)
+
+    def forward(self, x):
+        _, _, H, W = x.size()
+        backbone_out = self.backbone(x)
+        segmentation_head_out = self.segmentation_head(backbone_out)
+        y = segmentation_head_out
+        return y
+
+
+if __name__ == '__main__':
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(device)
+    x = torch.zeros(1, 3, 640, 640).to(device)
+
+    model_config = {
+        'backbone': 'mobilenetv2',
+        'fpem_repeat': 2,  # number of times the FPEM module is repeated
+        'pretrained': False,  # whether the backbone uses ImageNet pretrained weights
+        'segmentation_head': 'FPN'  # segmentation head: FPN or FPEM_FFM
+    }
+    model = Model(model_config=model_config).to(device)
+    y = model(x)
+
+    print(model)
+    # torch.save(model.state_dict(), 'PAN.pth')
models/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .resnet import *
+from .shufflenetv2 import *
+from .mobilenetv2 import *
+from .segmentation_head import FPEM_FFM, FPN
models/modules/mobilenetv2.py ADDED
@@ -0,0 +1,178 @@
+import math
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+__all__ = ['mobilenet_v2_x1_0']
+
+model_urls = {
+    # no pretrained weights have been found for this variant yet
+    'mobilenet_v2_x1_0': None,
+}
+def _make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8.
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v:
+    :param divisor:
+    :param min_value:
+    :return:
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+
+        if not (1 <= stride <= 2):
+            raise ValueError('illegal stride value')
+        self.stride = stride
+
+        self.exp_r = expand_ratio
+        hidden_dim = round(inp * self.exp_r)
+
+        if self.exp_r == 1:
+            self.branch = nn.Sequential(
+                # dw conv
+                self.depthwise_conv(hidden_dim, hidden_dim, 3, stride, padding=1, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+
+        else:
+            self.branch = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # dw
+                self.depthwise_conv(hidden_dim, hidden_dim, 3, stride, padding=1, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+
+        self.downsample = nn.Sequential(
+            nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
+            nn.BatchNorm2d(oup),
+        )
+
+        self.identity = stride == 1
+    @staticmethod
+    def depthwise_conv(i, o, kernel_size, stride, padding, bias=False):
+        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+    def forward(self, x):
+        if self.identity:
+            downsampx = self.downsample(x)
+            return downsampx + self.branch(x)
+        else:
+            return self.branch(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, width_mult=1.):
+        super(MobileNetV2, self).__init__()
+
+        if len(stages_repeats) != 5:
+            raise ValueError('expected stages_repeats as list of 5 positive ints')
+        if len(stages_out_channels) != 6:
+            raise ValueError('expected stages_out_channels as list of 6 positive ints')
+        self._stage_out_channels = stages_out_channels
+
+        self.tlist = [1, 6, 6, 6, 6]  # expansion ratio per stage
+        self.slist = [1, 2, 2, 2, 1]  # stride per stage
+
+        input_channels = 3
+        output_channels = self._stage_out_channels[0]  # 32
+        # output_channels = _make_divisible(output_channels * width_mult, 4 if width_mult == 0.1 else 8)
+
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
+            nn.BatchNorm2d(output_channels),
+            nn.ReLU6(inplace=True),
+        )
+        input_channels = output_channels
+
+
+        stage_names = ['stage{}'.format(i) for i in [2, 3, 4, 5, 6]]
+        for name, repeats, output_channels, t, s in zip(
+                stage_names, stages_repeats, self._stage_out_channels[1:], self.tlist, self.slist):
+
+            # output_channels = _make_divisible(output_channels * width_mult, 4 if width_mult == 0.1 else 8)
+
+            seq = [InvertedResidual(input_channels, output_channels, s, t)]
+            for i in range(repeats - 1):
+
+                seq.append(InvertedResidual(output_channels, output_channels, 1, t))
+
+
+
+            setattr(self, name, nn.Sequential(*seq))
+            input_channels = output_channels
+
+
+        output_channels = self._stage_out_channels[-1]
+
+        self.conv9 = nn.Sequential(
+            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(output_channels),
+            nn.ReLU6(inplace=True),
+        )
+
+
+
+    def forward(self, x):
+        x = self.conv1(x)
+        c2 = self.stage2(x)
+        c3 = self.stage3(c2)
+        c4 = self.stage4(c3)
+        c5 = self.stage5(c4)
+        c6 = self.stage6(c5)
+        # c7 = self.stage7(c6)
+        # c8 = self.stage8(c7)
+        # c9 = self.conv9(c8)
+
+        return c3, c4, c5, c6
+
+
+
+def _mobilenetv2(arch, pretrained, progress, *args, **kwargs):
+    model = MobileNetV2(*args, **kwargs)
+
+    if pretrained:
+        model_url = model_urls[arch]
+        if model_url is None:
+            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
+        else:
+            state_dict = load_state_dict_from_url(model_url, progress=progress)
+            model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def mobilenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
+    """
+    Constructs a MobileNetV2 with 1.0x output channels, as described in
+    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
+    <https://arxiv.org/abs/1801.04381>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _mobilenetv2('mobilenet_v2_x1_0', pretrained, progress,
+                        [1, 2, 3, 4, 3], [32, 16, 24, 40, 160, 160], **kwargs)
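The four feature maps returned by `forward` (c3-c6) are what models/model.py registers as `'out': [24, 40, 160, 160]` for the `mobilenetv2` backbone. A short shape-check sketch, assuming a 640x640 input:

import torch
from models.modules.mobilenetv2 import mobilenet_v2_x1_0

net = mobilenet_v2_x1_0(pretrained=False)
feats = net(torch.zeros(1, 3, 640, 640))
print([f.shape[1] for f in feats])  # channel counts: [24, 40, 160, 160]
print([f.shape[2] for f in feats])  # spatial sizes: [160, 80, 40, 40] (strides 4, 8, 16, 16)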
models/modules/resnet.py ADDED
@@ -0,0 +1,300 @@
+# -*- coding: utf-8 -*-
+
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+# from torchvision.models.utils import load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        c2 = self.layer1(x)
+        c3 = self.layer2(c2)
+        c4 = self.layer3(c3)
+        c5 = self.layer4(c4)
+
+        return c2, c3, c4, c5
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict, strict=False)
+        print('loaded pretrained model from ImageNet')
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNeXt-50 32x4d model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNeXt-101 32x8d model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+if __name__ == '__main__':
+    import torch
+    x = torch.zeros(1, 3, 640, 640)
+    net = resnext101_32x8d(pretrained=False)
+    y = net(x)
+    for u in y:
+        print(u.shape)
models/modules/segmentation_head.py ADDED
@@ -0,0 +1,206 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class FPN(nn.Module):
+    def __init__(self, backbone_out_channels, **kwargs):
+        """
+        :param backbone_out_channels: channel counts of the backbone feature maps
+        :param kwargs:
+        """
+        super().__init__()
+        # result_num = kwargs.get('result_num', 6)
+        inplace = True
+        conv_out = 256
+        # reduce layers
+        self.reduce_conv_c2 = nn.Sequential(
+            nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        self.reduce_conv_c3 = nn.Sequential(
+            nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        self.reduce_conv_c4 = nn.Sequential(
+            nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+
+        self.reduce_conv_c5 = nn.Sequential(
+            nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        # Smooth layers
+        self.smooth_p4 = nn.Sequential(
+            nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        self.smooth_p3 = nn.Sequential(
+            nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        self.smooth_p2 = nn.Sequential(
+            nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU(inplace=inplace)
+        )
+        # self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)
+
+        self.pred_conv = nn.Sequential(
+            nn.Conv2d(conv_out, 2, kernel_size=1, stride=1, padding=0),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        c2, c3, c4, c5 = x
+        # Top-down
+        p5 = self.reduce_conv_c5(c5)
+        p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
+        p4 = self.smooth_p4(p4)
+        p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
+        p3 = self.smooth_p3(p3)
+        p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
+        p2 = self.smooth_p2(p2)
+
+        x = self._upsample_cat(p2, p3, p4, p5)
+        x = self.conv(x)
+
+        # x = self.out_conv(x)
+
+        x = self.pred_conv(x)
+        return x
+
+    def _upsample_add(self, x, y):
+        return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=True) + y
+
+    def _upsample_cat(self, p2, p3, p4, p5):
+        h, w = p2.size()[2:]
+        p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=True)
+        p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=True)
+        p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=True)
+        return torch.cat([p2, p3, p4, p5], dim=1)
+
+
+class FPEM_FFM(nn.Module):
+    def __init__(self, backbone_out_channels, **kwargs):
+        """
+        PANnet
+        :param backbone_out_channels: channel counts of the backbone feature maps
+        """
+        super().__init__()
+        fpem_repeat = kwargs.get('fpem_repeat', 2)
+        conv_out = 128
+        # reduce layers
+        self.reduce_conv_c2 = nn.Sequential(
+            nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=conv_out, kernel_size=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU()
+        )
+        self.reduce_conv_c3 = nn.Sequential(
+            nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=conv_out, kernel_size=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU()
+        )
+        self.reduce_conv_c4 = nn.Sequential(
+            nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=conv_out, kernel_size=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU()
+        )
+        self.reduce_conv_c5 = nn.Sequential(
+            nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=conv_out, kernel_size=1),
+            nn.BatchNorm2d(conv_out),
+            nn.ReLU()
+        )
+        self.fpems = nn.ModuleList()
+        for i in range(fpem_repeat):
+            self.fpems.append(FPEM(conv_out))
+        self.out_conv = nn.Conv2d(in_channels=conv_out * 4, out_channels=6, kernel_size=1)
+
+    def forward(self, x):
+        c2, c3, c4, c5 = x
+        # reduce channel
+        c2 = self.reduce_conv_c2(c2)
+        c3 = self.reduce_conv_c3(c3)
+        c4 = self.reduce_conv_c4(c4)
+        c5 = self.reduce_conv_c5(c5)
+
+        # FPEM
+        for i, fpem in enumerate(self.fpems):
+            c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
+            if i == 0:
+                c2_ffm = c2
+                c3_ffm = c3
+                c4_ffm = c4
+                c5_ffm = c5
+            else:
+                c2_ffm += c2
+                c3_ffm += c3
+                c4_ffm += c4
+                c5_ffm += c5
+
+        # FFM
+        c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:], mode='bilinear')
+        c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:], mode='bilinear')
+        c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:], mode='bilinear')
+        Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
+        y = self.out_conv(Fy)
+        return y
+
+
+class FPEM(nn.Module):
+    def __init__(self, in_channels=128):
+        super().__init__()
+        self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
+        self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
+        self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
+        self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
+        self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
+        self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
+
+    def forward(self, c2, c3, c4, c5):
+        # up stage
+        c4 = self.up_add1(self._upsample_add(c5, c4))
+        c3 = self.up_add2(self._upsample_add(c4, c3))
+        c2 = self.up_add3(self._upsample_add(c3, c2))
+
+        # down stage
+        c3 = self.down_add1(self._upsample_add(c3, c2))
+        c4 = self.down_add2(self._upsample_add(c4, c3))
+        c5 = self.down_add3(self._upsample_add(c5, c4))
+        return c2, c3, c4, c5
+
+    def _upsample_add(self, x, y):
+        return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(SeparableConv2d, self).__init__()
+
+        self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
+                                        stride=stride, groups=in_channels)
+        self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
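Both heads take the four backbone feature maps and emit maps at c2 resolution — FPN ends with a 2-channel sigmoid output (probability map plus threshold map), FPEM_FFM with 6 channels. A minimal shape-check sketch for the FPN head in isolation, assuming mobilenetv2-sized inputs:

import torch
from models.modules.segmentation_head import FPN

head = FPN(backbone_out_channels=[24, 40, 160, 160])
feats = [torch.zeros(1, c, s, s) for c, s in zip([24, 40, 160, 160], [160, 80, 40, 40])]
print(head(feats).shape)  # torch.Size([1, 2, 160, 160])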
models/modules/shufflenetv2.py ADDED
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+# from torchvision.models.utils import load_state_dict_from_url
+
+__all__ = [
+    'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
+    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
+]
+
+model_urls = {
+    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
+    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
+    'shufflenetv2_x1.5': None,
+    'shufflenetv2_x2.0': None,
+}
+
+def channel_shuffle(x, groups):
+    batchsize, num_channels, height, width = x.data.size()
+    channels_per_group = num_channels // groups
+
+    # reshape
+    x = x.view(batchsize, groups,
+               channels_per_group, height, width)
+
+    x = torch.transpose(x, 1, 2).contiguous()
+
+    # flatten
+    x = x.view(batchsize, -1, height, width)
+
+    return x
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super(InvertedResidual, self).__init__()
+
+        if not (1 <= stride <= 3):
+            raise ValueError('illegal stride value')
+        self.stride = stride
+
+        branch_features = oup // 2
+        assert (self.stride != 1) or (inp == branch_features << 1)
+
+        if self.stride > 1:
+            self.branch1 = nn.Sequential(
+                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+                nn.BatchNorm2d(inp),
+                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+                nn.BatchNorm2d(branch_features),
+                nn.ReLU(inplace=True),
+            )
+
+        self.branch2 = nn.Sequential(
+            nn.Conv2d(inp if (self.stride > 1) else branch_features,
+                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.BatchNorm2d(branch_features),
+            nn.ReLU(inplace=True),
+            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+            nn.BatchNorm2d(branch_features),
+            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.BatchNorm2d(branch_features),
+            nn.ReLU(inplace=True),
+        )
+
+    @staticmethod
+    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+    def forward(self, x):
+        if self.stride == 1:
+            x1, x2 = x.chunk(2, dim=1)
+            out = torch.cat((x1, self.branch2(x2)), dim=1)
+        else:
+            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+
+        out = channel_shuffle(out, 2)
+
+        return out
+
+
+class ShuffleNetV2(nn.Module):
+    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
+        super(ShuffleNetV2, self).__init__()
+
+        if len(stages_repeats) != 3:
+            raise ValueError('expected stages_repeats as list of 3 positive ints')
+        if len(stages_out_channels) != 5:
+            raise ValueError('expected stages_out_channels as list of 5 positive ints')
+        self._stage_out_channels = stages_out_channels
+
+        input_channels = 3
+        output_channels = self._stage_out_channels[0]
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
+            nn.BatchNorm2d(output_channels),
+            nn.ReLU(inplace=True),
+        )
+        input_channels = output_channels
+
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
+        for name, repeats, output_channels in zip(
+                stage_names, stages_repeats, self._stage_out_channels[1:]):
+            seq = [InvertedResidual(input_channels, output_channels, 2)]
+            for i in range(repeats - 1):
+                seq.append(InvertedResidual(output_channels, output_channels, 1))
+            setattr(self, name, nn.Sequential(*seq))
+            input_channels = output_channels
+
+        output_channels = self._stage_out_channels[-1]
+        self.conv5 = nn.Sequential(
+            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(output_channels),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.conv1(x)
+        c2 = self.maxpool(x)
+        c3 = self.stage2(c2)
+        c4 = self.stage3(c3)
+        c5 = self.stage4(c4)
+        # c5 = self.conv5(c5)
+        return c2, c3, c4, c5
+
+
+def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
+    model = ShuffleNetV2(*args, **kwargs)
+
+    if pretrained:
+        model_url = model_urls[arch]
+        if model_url is None:
+            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
+        else:
+            state_dict = load_state_dict_from_url(model_url, progress=progress)
+            model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
+    """
+    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
+                         [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
+
+
+def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
+    """
+    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
+                         [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
+
+
+def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
+    """
+    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
+                         [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
+
+
+def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
+    """
+    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
+                         [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+torch
+torchvision
+opencv-python
+numpy
+pyclipper
+gradio
+matplotlib
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .util import *
+from .metrics import *
+from .schedulers import *
+from .cal_recall.script import cal_recall_precison_f1
utils/cal_recall/__init__.py ADDED
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+
+from .script import cal_recall_precison_f1
+__all__ = ['cal_recall_precison_f1']
utils/cal_recall/rrc_evaluation_funcs.py ADDED
@@ -0,0 +1,394 @@
+#!/usr/bin/env python3
+# encoding: UTF-8
+import json
+import sys; sys.path.append('./')
+import zipfile
+import re
+import sys
+import os
+import codecs
+import traceback
+import numpy as np
+from utils import order_points_clockwise
+
+def print_help():
+    sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' % sys.argv[0])
+    sys.exit(2)
+
+
+def load_zip_file_keys(file, fileNameRegExp=''):
+    """
+    Returns an array with the entries of the ZIP file that match the regular expression.
+    The keys are the file names or the capturing group defined in fileNameRegExp.
+    """
+    try:
+        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+    except:
+        raise Exception('Error loading the ZIP archive.')
+
+    pairs = []
+
+    for name in archive.namelist():
+        addFile = True
+        keyName = name
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append(keyName)
+
+    return pairs
+
+
+def load_zip_file(file, fileNameRegExp='', allEntries=False):
+    """
+    Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
+    The keys are the file names or the capturing group defined in fileNameRegExp.
+    allEntries validates that all entries in the ZIP file pass the fileNameRegExp.
+    """
+    try:
+        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+    except:
+        raise Exception('Error loading the ZIP archive')
+
+    pairs = []
+    for name in archive.namelist():
+        addFile = True
+        keyName = name
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append([keyName, archive.read(name)])
+        else:
+            if allEntries:
+                raise Exception('ZIP entry not valid: %s' % name)
+
+    return dict(pairs)
+
+
+def load_folder_file(file, fileNameRegExp='', allEntries=False):
+    """
+    Returns an array with the contents (filtered by fileNameRegExp) of a folder.
+    The keys are the file names or the capturing group defined in fileNameRegExp.
+    allEntries validates that all entries in the folder pass the fileNameRegExp.
+    """
+    pairs = []
+    for name in os.listdir(file):
+        addFile = True
+        keyName = name
+        print(name)
+        print(fileNameRegExp)
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            print(m)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append([keyName, open(os.path.join(file, name), encoding="utf-8").read()])
+        else:
+            if allEntries:
+                raise Exception('Folder entry not valid: %s' % name)
+
+    return dict(pairs)
+
+
+def decode_utf8(raw):
+    """
+    Returns a Unicode object on success, or None on failure
+    """
+    try:
+        raw = codecs.decode(raw, 'utf-8', 'replace')
+        # extract the BOM if it exists
+        raw = raw.encode('utf8')
+        if raw.startswith(codecs.BOM_UTF8):
+            raw = raw.replace(codecs.BOM_UTF8, b'', 1)  # bytes replacement (Python 3)
+        return raw.decode('utf-8')
+    except:
+        return None
+
+def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+    """
+    Validates all lines of the file by calling the line-validation function on each one.
+    """
+    utf8File = decode_utf8(file_contents)
+    if (utf8File is None):
+        raise Exception("The file %s is not UTF-8" % fileName)
+
+    lines = utf8File.split("\r\n" if CRLF else "\n")
+    for line in lines:
+        line = line.replace("\r", "").replace("\n", "")
+        if (line != ""):
+            try:
+                validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+            except Exception as e:
+                raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode('utf-8', 'replace'))
+
+
+
+def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0):
+    """
+    Validate the format of the line. If the line is not valid an exception will be raised.
+    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+    Possible values are:
+    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+    """
+    get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+
+
+def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+    """
+    Validate the format of the line. If the line is not valid an exception will be raised.
+    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+    Possible values are:
+    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+    Returns values from a textline. Points, [Confidences], [Transcriptions]
+    """
+    confidence = 0.0
+    transcription = ""
+    points = []
+
+    numPoints = 4
+
+    if LTRB:
+
+        numPoints = 4
+
+        if withTranscription and withConfidence:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+            if m == None:
+                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
+        elif withConfidence:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
+        elif withTranscription:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
+        else:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
+
+        xmin = int(m.group(1))
+        ymin = int(m.group(2))
+        xmax = int(m.group(3))
+        ymax = int(m.group(4))
+        if (xmax < xmin):
+            raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
+        if (ymax < ymin):
+            raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
+
+        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+        if (imWidth > 0 and imHeight > 0):
+            validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
+            validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
+
+    else:
+
+        numPoints = 8
+
+        if withTranscription and withConfidence:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
+        elif withConfidence:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
+        elif withTranscription:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
+        else:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
+
+        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+        points = order_points_clockwise(np.array(points).reshape(-1, 2)).reshape(-1)
+        validate_clockwise_points(points)
+
+        if (imWidth > 0 and imHeight > 0):
+            validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
+            validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
+            validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
+            validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)
+
+
+    if withConfidence:
+        try:
+            confidence = float(m.group(numPoints + 1))
+        except ValueError:
+            raise Exception("Confidence value must be a float")
+
+    if withTranscription:
+        posTranscription = numPoints + (2 if withConfidence else 1)
+        transcription = m.group(posTranscription)
+        m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
+        if m2 != None:  # transcription with double quotes: extract the value and replace escaped characters
+            transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
+
+    return points, confidence, transcription
+
+
+def validate_point_inside_bounds(x, y, imWidth, imHeight):
+    if (x < 0 or x > imWidth):
+        raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight))
+    if (y < 0 or y > imHeight):
+        raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight))
+
+def validate_clockwise_points(points):
+    """
+    Validates that the 4 points that delimit a polygon are in clockwise order.
+    """
+
+    if len(points) != 8:
+        raise Exception("Points list not valid." + str(len(points)))
+
+    point = [
+        [int(points[0]), int(points[1])],
+        [int(points[2]), int(points[3])],
+        [int(points[4]), int(points[5])],
+        [int(points[6]), int(points[7])]
+    ]
+    edge = [
+        (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
+        (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
+        (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
+        (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
+    ]
+
+    summatory = edge[0] + edge[1] + edge[2] + edge[3]
+    if summatory > 0:
+        raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
+
+def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, sort_by_confidences=True):
+    """
+    Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
+    xmin,ymin,xmax,ymax,[confidence],[transcription]
+    x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
+    """
+    pointsList = []
+    transcriptionsList = []
+    confidencesList = []
+
+    lines = content.split("\r\n" if CRLF else "\n")
+    for line in lines:
+        line = line.replace("\r", "").replace("\n", "")
+        if (line != ""):
+            points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+            pointsList.append(points)
+            transcriptionsList.append(transcription)
+            confidencesList.append(confidence)
+
+    if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
+        import numpy as np
+        sorted_ind = np.argsort(-np.array(confidencesList))
+        confidencesList = [confidencesList[i] for i in sorted_ind]
+        pointsList = [pointsList[i] for i in sorted_ind]
+        transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
+
+    return pointsList, confidencesList, transcriptionsList
+
+def main_evaluation(p, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True, per_sample=True):
+    """
+    This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
+    Params:
+    p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
319
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
320
+ validate_data_fn: points to a method that validates the corrct format of the submission
321
+ evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
322
+ """
323
+ evalParams = default_evaluation_params_fn()
324
+ if 'p' in p.keys():
325
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
326
+
327
+ resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
328
+ try:
329
+ # validate_data_fn(p['g'], p['s'], evalParams)
330
+ evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
331
+ resDict.update(evalData)
332
+
333
+ except Exception as e:
334
+ traceback.print_exc()
335
+ resDict['Message']= str(e)
336
+ resDict['calculated']=False
337
+
338
+ if 'o' in p:
339
+ if not os.path.exists(p['o']):
340
+ os.makedirs(p['o'])
341
+
342
+ resultsOutputname = p['o'] + '/results.zip'
343
+ outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
344
+
345
+ del resDict['per_sample']
346
+ if 'output_items' in resDict.keys():
347
+ del resDict['output_items']
348
+
349
+ outZip.writestr('method.json',json.dumps(resDict))
350
+
351
+ if not resDict['calculated']:
352
+ if show_result:
353
+ sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
354
+ if 'o' in p:
355
+ outZip.close()
356
+ return resDict
357
+
358
+ if 'o' in p:
359
+ if per_sample == True:
360
+ for k,v in evalData['per_sample'].iteritems():
361
+ outZip.writestr( k + '.json',json.dumps(v))
362
+
363
+ if 'output_items' in evalData.keys():
364
+ for k, v in evalData['output_items'].iteritems():
365
+ outZip.writestr( k,v)
366
+
367
+ outZip.close()
368
+
369
+ if show_result:
370
+ sys.stdout.write("Calculated!")
371
+ sys.stdout.write(json.dumps(resDict['method']))
372
+
373
+ return resDict
374
+
375
+
376
+ def main_validation(default_evaluation_params_fn,validate_data_fn):
377
+ """
378
+ This process validates a method
379
+ Params:
380
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
381
+ validate_data_fn: points to a method that validates the corrct format of the submission
382
+ """
383
+ try:
384
+ p = dict([s[1:].split('=') for s in sys.argv[1:]])
385
+ evalParams = default_evaluation_params_fn()
386
+ if 'p' in p.keys():
387
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
388
+
389
+ validate_data_fn(p['g'], p['s'], evalParams)
390
+ print('SUCCESS')
391
+ sys.exit(0)
392
+ except Exception as e:
393
+ print(str(e))
394
+ sys.exit(101)
utils/cal_recall/script.py ADDED
@@ -0,0 +1,325 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ from collections import namedtuple
+ from . import rrc_evaluation_funcs
+ # import Polygon as plg
+ import shapely.geometry as plg
+ import numpy as np
+
+
+ def default_evaluation_params():
+     """
+     default_evaluation_params: Default parameters to use for the validation and evaluation.
+     """
+     return {
+         'IOU_CONSTRAINT': 0.5,
+         'AREA_PRECISION_CONSTRAINT': 0.5,
+         'GT_SAMPLE_NAME_2_ID': '.+_([0-9]+).txt',
+         'DET_SAMPLE_NAME_2_ID': 'res_.+_([0-9]+).txt',
+         'LTRB': False,  # LTRB: 2 points (left,top,right,bottom) or 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
+         'CRLF': False,  # Lines are delimited by Windows CRLF format
+         'CONFIDENCES': False,  # Detections must include a confidence value. AP will be calculated
+         'PER_SAMPLE_RESULTS': True  # Generate per-sample results and produce data for visualization
+     }
+
+
+ def validate_data(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
+     Also validates that there are no missing files in the folder.
+     If an error is detected, the method raises it.
+     """
+     gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+
+     subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+
+     # Validate format of GroundTruth
+     for k in gt:
+         rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
+
+     # Validate format of results
+     for k in subm:
+         if k not in gt:
+             raise Exception("The sample %s not present in GT" % k)
+
+         rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
+                                                     False, evaluationParams['CONFIDENCES'])
+
+
+ def evaluate_method(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method evaluate_method: evaluates the method and returns the results.
+     Results: dictionary with the following values:
+     - method (required) Global method metrics. Ex: { 'Precision': 0.8, 'Recall': 0.9 }
+     - samples (optional) Per-sample metrics. Ex: { 'sample1': { 'Precision': 0.8, 'Recall': 0.9 }, 'sample2': { 'Precision': 0.8, 'Recall': 0.9 } }
+     """
+
+     def polygon_from_points(points):
+         """
+         Returns a polygon object from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
+         """
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(points[0])
+         resBoxes[0, 4] = int(points[1])
+         resBoxes[0, 1] = int(points[2])
+         resBoxes[0, 5] = int(points[3])
+         resBoxes[0, 2] = int(points[4])
+         resBoxes[0, 6] = int(points[5])
+         resBoxes[0, 3] = int(points[6])
+         resBoxes[0, 7] = int(points[7])
+         pointMat = resBoxes[0].reshape([2, 4]).T
+         return plg.Polygon(pointMat)
+
+     def rectangle_to_polygon(rect):
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(rect.xmin)
+         resBoxes[0, 4] = int(rect.ymax)
+         resBoxes[0, 1] = int(rect.xmin)
+         resBoxes[0, 5] = int(rect.ymin)
+         resBoxes[0, 2] = int(rect.xmax)
+         resBoxes[0, 6] = int(rect.ymin)
+         resBoxes[0, 3] = int(rect.xmax)
+         resBoxes[0, 7] = int(rect.ymax)
+
+         pointMat = resBoxes[0].reshape([2, 4]).T
+
+         return plg.Polygon(pointMat)
+
+     def rectangle_to_points(rect):
+         points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
+                   int(rect.xmin), int(rect.ymin)]
+         return points
+
+     def get_union(pD, pG):
+         areaA = pD.area
+         areaB = pG.area
+         return areaA + areaB - get_intersection(pD, pG)
+
+     def get_intersection_over_union(pD, pG):
+         try:
+             return get_intersection(pD, pG) / get_union(pD, pG)
+         except ZeroDivisionError:
+             return 0
+
+     def get_intersection(pD, pG):
+         pInt = pD & pG
+         if pInt.is_empty:
+             return 0
+         return pInt.area
+
+     def compute_ap(confList, matchList, numGtCare):
+         correct = 0
+         AP = 0
+         if len(confList) > 0:
+             confList = np.array(confList)
+             matchList = np.array(matchList)
+             sorted_ind = np.argsort(-confList)
+             confList = confList[sorted_ind]
+             matchList = matchList[sorted_ind]
+             for n in range(len(confList)):
+                 match = matchList[n]
+                 if match:
+                     correct += 1
+                     AP += float(correct) / (n + 1)
+
+             if numGtCare > 0:
+                 AP /= numGtCare
+
+         return AP
+
+     perSampleMetrics = {}
+
+     matchedSum = 0
+
+     Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+
+     gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+     subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+
+     numGlobalCareGt = 0
+     numGlobalCareDet = 0
+
+     arrGlobalConfidences = []
+     arrGlobalMatches = []
+
+     for resFile in gt:
+
+         gtFile = gt[resFile]  # rrc_evaluation_funcs.decode_utf8(gt[resFile])
+         recall = 0
+         precision = 0
+         hmean = 0
+
+         detMatched = 0
+
+         iouMat = np.empty([1, 1])
+
+         gtPols = []
+         detPols = []
+
+         gtPolPoints = []
+         detPolPoints = []
+
+         # Array of ground-truth polygons' keys marked as don't care
+         gtDontCarePolsNum = []
+         # Array of detected polygons matched with a don't-care GT
+         detDontCarePolsNum = []
+
+         pairs = []
+         detMatchedNums = []
+
+         arrSampleConfidences = []
+         arrSampleMatch = []
+         sampleAP = 0
+
+         evaluationLog = ""
+
+         pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
+             gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
+         for n in range(len(pointsList)):
+             points = pointsList[n]
+             transcription = transcriptionsList[n]
+             dontCare = transcription == "###"
+             if evaluationParams['LTRB']:
+                 gtRect = Rectangle(*points)
+                 gtPol = rectangle_to_polygon(gtRect)
+             else:
+                 gtPol = polygon_from_points(points)
+             gtPols.append(gtPol)
+             gtPolPoints.append(points)
+             if dontCare:
+                 gtDontCarePolsNum.append(len(gtPols) - 1)
+
+         evaluationLog += "GT polygons: " + str(len(gtPols)) + (
+             " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
+
+         if resFile in subm:
+
+             detFile = subm[resFile]  # rrc_evaluation_funcs.decode_utf8(subm[resFile])
+
+             pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
+                 detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], False, evaluationParams['CONFIDENCES'])
+             for n in range(len(pointsList)):
+                 points = pointsList[n]
+
+                 if evaluationParams['LTRB']:
+                     detRect = Rectangle(*points)
+                     detPol = rectangle_to_polygon(detRect)
+                 else:
+                     detPol = polygon_from_points(points)
+                 detPols.append(detPol)
+                 detPolPoints.append(points)
+                 if len(gtDontCarePolsNum) > 0:
+                     for dontCarePol in gtDontCarePolsNum:
+                         dontCarePol = gtPols[dontCarePol]
+                         intersected_area = get_intersection(dontCarePol, detPol)
+                         pdDimensions = detPol.area
+                         precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+                         if precision > evaluationParams['AREA_PRECISION_CONSTRAINT']:
+                             detDontCarePolsNum.append(len(detPols) - 1)
+                             break
+
+             evaluationLog += "DET polygons: " + str(len(detPols)) + (
+                 " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
+
+             if len(gtPols) > 0 and len(detPols) > 0:
+                 # Calculate IoU and precision matrices
+                 outputShape = [len(gtPols), len(detPols)]
+                 iouMat = np.empty(outputShape)
+                 gtRectMat = np.zeros(len(gtPols), np.int8)
+                 detRectMat = np.zeros(len(detPols), np.int8)
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         pG = gtPols[gtNum]
+                         pD = detPols[detNum]
+                         iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 \
+                                 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+                             if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
+                                 gtRectMat[gtNum] = 1
+                                 detRectMat[detNum] = 1
+                                 detMatched += 1
+                                 pairs.append({'gt': gtNum, 'det': detNum})
+                                 detMatchedNums.append(detNum)
+                                 evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
+
+             if evaluationParams['CONFIDENCES']:
+                 for detNum in range(len(detPols)):
+                     if detNum not in detDontCarePolsNum:
+                         # exclude the don't-care detections
+                         match = detNum in detMatchedNums
+
+                         arrSampleConfidences.append(confidencesList[detNum])
+                         arrSampleMatch.append(match)
+
+                         arrGlobalConfidences.append(confidencesList[detNum])
+                         arrGlobalMatches.append(match)
+
+         numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+         numDetCare = (len(detPols) - len(detDontCarePolsNum))
+         if numGtCare == 0:
+             recall = float(1)
+             precision = float(0) if numDetCare > 0 else float(1)
+             sampleAP = precision
+         else:
+             recall = float(detMatched) / numGtCare
+             precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
+             if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
+                 sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
+
+         hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
+
+         matchedSum += detMatched
+         numGlobalCareGt += numGtCare
+         numGlobalCareDet += numDetCare
+
+         if evaluationParams['PER_SAMPLE_RESULTS']:
+             perSampleMetrics[resFile] = {
+                 'precision': precision,
+                 'recall': recall,
+                 'hmean': hmean,
+                 'pairs': pairs,
+                 'AP': sampleAP,
+                 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+                 'gtPolPoints': gtPolPoints,
+                 'detPolPoints': detPolPoints,
+                 'gtDontCare': gtDontCarePolsNum,
+                 'detDontCare': detDontCarePolsNum,
+                 'evaluationParams': evaluationParams,
+                 'evaluationLog': evaluationLog
+             }
+
+     # Compute MAP and MAR
+     AP = 0
+     if evaluationParams['CONFIDENCES']:
+         AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
+
+     methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+     methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+     methodHmean = 0 if methodRecall + methodPrecision == 0 else \
+         2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+
+     methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
+
+     resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
+
+     return resDict
+
+
+ def cal_recall_precison_f1(gt_path, result_path, show_result=False):
+     p = {'g': gt_path, 's': result_path}
+     result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method,
+                                                   show_result)
+     return result['method']
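cal_recall_precison_f1 is the intended entry point of this module. A minimal usage sketch, assuming the repository root is importable and that the folders hold ICDAR-style txt files whose names match GT_SAMPLE_NAME_2_ID (e.g. gt_img_1.txt) and DET_SAMPLE_NAME_2_ID (e.g. res_img_1.txt); folder names here are hypothetical:

    from utils.cal_recall.script import cal_recall_precison_f1  # import path assumed

    metrics = cal_recall_precison_f1(gt_path='./gt', result_path='./result')
    print(metrics)  # {'precision': ..., 'recall': ..., 'hmean': ..., 'AP': ...}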
utils/make_trainfile.py ADDED
@@ -0,0 +1,20 @@
+ # -*- coding: utf-8 -*-
+
+ import os
+ import glob
+ import pathlib
+
+ data_path = r'E:\zj\dataset\icdar2015\test'
+ # data_path/img holds the images
+ # data_path/gt holds the label files
+
+ f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
+ for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
+     d = pathlib.Path(img_path)
+     label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
+     if os.path.exists(img_path) and os.path.exists(label_path):
+         print(img_path, label_path)
+     else:
+         print('missing', img_path, label_path)
+     f_w.write('{}\t{}\n'.format(img_path, label_path))
+ f_w.close()
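Each line of the generated test.txt pairs an image path with its label path, separated by a tab; an illustrative line (the sample name is invented):

    E:\zj\dataset\icdar2015\test\img\img_1.jpg\tE:\zj\dataset\icdar2015\test\gt\gt_img_1.txt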
utils/metrics.py ADDED
@@ -0,0 +1,53 @@
+ # Adapted from score written by wkentaro
+ # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
+
+ import numpy as np
+
+
+ class runningScore(object):
+
+     def __init__(self, n_classes):
+         self.n_classes = n_classes
+         self.confusion_matrix = np.zeros((n_classes, n_classes))
+
+     def _fast_hist(self, label_true, label_pred, n_class):
+         mask = (label_true >= 0) & (label_true < n_class)
+
+         if np.sum(label_pred[mask] < 0) > 0:
+             print(label_pred[label_pred < 0])
+         hist = np.bincount(n_class * label_true[mask].astype(int) +
+                            label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
+         return hist
+
+     def update(self, label_trues, label_preds):
+         for lt, lp in zip(label_trues, label_preds):
+             try:
+                 self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
+             except Exception:
+                 # skip samples whose shapes or labels cannot be histogrammed
+                 pass
+
+     def get_scores(self):
+         """Returns accuracy score evaluation result.
+         - overall accuracy
+         - mean accuracy
+         - mean IU
+         - fwavacc
+         """
+         hist = self.confusion_matrix
+         acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
+         acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
+         acc_cls = np.nanmean(acc_cls)
+         iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
+         mean_iu = np.nanmean(iu)
+         freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
+         fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
+         cls_iu = dict(zip(range(self.n_classes), iu))
+
+         return {'Overall Acc': acc,
+                 'Mean Acc': acc_cls,
+                 'FreqW Acc': fwavacc,
+                 'Mean IoU': mean_iu, }, cls_iu
+
+     def reset(self):
+         self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
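A toy check of runningScore on a two-class batch; the arrays are invented, and since update expects an iterable of label maps, a leading batch dimension is added:

    import numpy as np

    score = runningScore(2)
    gt = np.array([[0, 1], [1, 1]])
    pred = np.array([[0, 1], [0, 1]])
    score.update(gt[None, ...], pred[None, ...])  # one (2, 2) map each
    scores, cls_iu = score.get_scores()
    print(scores['Overall Acc'])  # ~0.75: three of four pixels agree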
utils/schedulers.py ADDED
@@ -0,0 +1,64 @@
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class ConstantLR(_LRScheduler):
+     def __init__(self, optimizer, last_epoch=-1):
+         super(ConstantLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         return [base_lr for base_lr in self.base_lrs]
+
+
+ class PolynomialLR(_LRScheduler):
+     def __init__(self, optimizer, max_iter, power=0.9, last_epoch=-1):
+         self.max_iter = max_iter
+         self.power = power
+         super(PolynomialLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power
+         return [base_lr * factor for base_lr in self.base_lrs]
+
+
+ class WarmUpLR(_LRScheduler):
+     def __init__(self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1):
+         self.mode = mode
+         self.scheduler = scheduler
+         self.warmup_iters = warmup_iters
+         self.gamma = gamma
+         super(WarmUpLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         cold_lrs = self.scheduler.get_lr()
+
+         if self.last_epoch < self.warmup_iters:
+             if self.mode == "linear":
+                 alpha = self.last_epoch / float(self.warmup_iters)
+                 factor = self.gamma * (1 - alpha) + alpha
+             elif self.mode == "constant":
+                 factor = self.gamma
+             else:
+                 raise KeyError("WarmUp type {} not implemented".format(self.mode))
+
+             return [factor * base_lr for base_lr in cold_lrs]
+
+         return cold_lrs
+
+
+ if __name__ == '__main__':
+     import torch
+     from torchvision.models import resnet18
+     from matplotlib import pyplot as plt
+
+     max_iter = 600 * 125
+     model = resnet18()
+     op = torch.optim.SGD(model.parameters(), 0.001)
+     sc = PolynomialLR(op, max_iter)
+     lr = []
+     for i in range(max_iter):
+         sc.step()
+         print(i, sc.last_epoch, sc.get_lr()[0])
+         lr.append(sc.get_lr()[0])
+     plt.plot(list(range(max_iter)), lr)
+     plt.show()
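WarmUpLR scales the wrapped scheduler's rates for the first warmup_iters steps and then passes them through unchanged. Because it reads the inner scheduler's current values, one workable pattern is to step both schedulers; a hedged sketch (optimizer and iteration counts are invented):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=0.01)
    base = PolynomialLR(opt, max_iter=1000, power=0.9)
    sched = WarmUpLR(opt, base, mode='linear', warmup_iters=100, gamma=0.1)

    for step in range(1000):
        opt.step()
        base.step()   # advance the wrapped polynomial schedule
        sched.step()  # apply the warmup factor during the first 100 steps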
utils/util.py ADDED
@@ -0,0 +1,137 @@
+ # -*- coding: utf-8 -*-
+
+ import time
+ import json
+ import cv2
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ def setup_logger(log_file_path: str = None):
+     """Return a logger with a default ColoredFormatter."""
+     import logging
+     from colorlog import ColoredFormatter
+     logging.basicConfig(filename=log_file_path,
+                         format='%(asctime)s %(levelname)-8s %(filename)s[line:%(lineno)d]: %(message)s',
+                         datefmt='%Y-%m-%d %H:%M:%S')
+     formatter = ColoredFormatter(
+         "%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s[line:%(lineno)d]: %(message)s",
+         datefmt='%Y-%m-%d %H:%M:%S',
+         reset=True,
+         log_colors={
+             'DEBUG': 'blue',
+             'INFO': 'green',
+             'WARNING': 'yellow',
+             'ERROR': 'red',
+             'CRITICAL': 'red',
+         })
+
+     logger = logging.getLogger('PAN')
+     handler = logging.StreamHandler()
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(logging.DEBUG)
+     logger.info('logger init finished')
+     return logger
+
+
+ def exe_time(func):
+     """Decorator that prints how long the wrapped function took."""
+     def newFunc(*args, **kwargs):
+         t0 = time.time()
+         back = func(*args, **kwargs)
+         print("{} cost {:.3f}s".format(func.__name__, time.time() - t0))
+         return back
+
+     return newFunc
+
+
+ def save_json(data, json_path):
+     with open(json_path, mode='w', encoding='utf8') as f:
+         json.dump(data, f, indent=4)
+
+
+ def load_json(json_path):
+     with open(json_path, mode='r', encoding='utf8') as f:
+         data = json.load(f)
+     return data
+
+
+ def show_img(imgs: np.ndarray, color=False):
+     if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color):
+         imgs = np.expand_dims(imgs, axis=0)
+     for img in imgs:
+         plt.figure()
+         plt.imshow(img, cmap=None if color else 'gray')
+
+
+ def draw_bbox(img_path, result, color=(128, 240, 128), thickness=3):
+     # img_path can be a file path or an already-loaded BGR image
+     if isinstance(img_path, str):
+         img_path = cv2.imread(img_path)
+         # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
+     img_path = img_path.copy()
+     for point in result:
+         point = point.astype(int)
+         cv2.line(img_path, tuple(point[0]), tuple(point[1]), color, thickness)
+         cv2.line(img_path, tuple(point[1]), tuple(point[2]), color, thickness)
+         cv2.line(img_path, tuple(point[2]), tuple(point[3]), color, thickness)
+         cv2.line(img_path, tuple(point[3]), tuple(point[0]), color, thickness)
+     return img_path
+
+
+ def cal_text_score(texts, gt_texts, training_masks, running_metric_text):
+     training_masks = training_masks.data.cpu().numpy()
+     pred_text = torch.sigmoid(texts).data.cpu().numpy() * training_masks
+     pred_text[pred_text <= 0.5] = 0
+     pred_text[pred_text > 0.5] = 1
+     pred_text = pred_text.astype(np.int32)
+     gt_text = gt_texts.data.cpu().numpy() * training_masks
+     gt_text = gt_text.astype(np.int32)
+     running_metric_text.update(gt_text, pred_text)
+     score_text, _ = running_metric_text.get_scores()
+     return score_text
+
+
+ def cal_kernel_score(kernel, gt_kernel, gt_texts, training_masks, running_metric_kernel):
+     mask = (gt_texts * training_masks.float()).data.cpu().numpy()
+     pred_kernel = torch.sigmoid(kernel).data.cpu().numpy()
+     pred_kernel[pred_kernel <= 0.5] = 0
+     pred_kernel[pred_kernel > 0.5] = 1
+     pred_kernel = (pred_kernel * mask).astype(np.int32)
+     gt_kernel = gt_kernel.data.cpu().numpy()
+     gt_kernel = (gt_kernel * mask).astype(np.int32)
+     running_metric_kernel.update(gt_kernel, pred_kernel)
+     score_kernel, _ = running_metric_kernel.get_scores()
+     return score_kernel
+
+
+ def order_points_clockwise(pts):
+     rect = np.zeros((4, 2), dtype="float32")
+     s = pts.sum(axis=1)
+     rect[0] = pts[np.argmin(s)]
+     rect[2] = pts[np.argmax(s)]
+     diff = np.diff(pts, axis=1)
+     rect[1] = pts[np.argmin(diff)]
+     rect[3] = pts[np.argmax(diff)]
+     return rect
+
+
+ def order_points_clockwise_list(pts):
+     pts = pts.tolist()
+     pts.sort(key=lambda x: (x[1], x[0]))
+     pts[:2] = sorted(pts[:2], key=lambda x: x[0])
+     pts[2:] = sorted(pts[2:], key=lambda x: -x[0])
+     pts = np.array(pts)
+     return pts
+
+
+ if __name__ == '__main__':
+     # box = np.array([382, 1080, 443, 999, 423, 1014, 362, 1095]).reshape(-1, 2)
+     box = np.array([0, 4, 2, 2, 0, 8, 4, 4]).reshape(-1, 2)
+     # box = np.array([0, 0, 2, 2, 0, 4, 4, 4]).reshape(-1, 2)
+     print(box)
+     print(order_points_clockwise(box))
+     print(order_points_clockwise_list(box))
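cal_text_score binarizes the sigmoid outputs at 0.5, masks them, and accumulates the result into a runningScore; a toy invocation (shapes and values invented, import path assumed):

    import torch
    from utils.metrics import runningScore  # import path assumed

    texts = torch.randn(1, 32, 32)                    # raw text-map logits
    gt_texts = (torch.rand(1, 32, 32) > 0.5).float()  # fake binary ground truth
    training_masks = torch.ones(1, 32, 32)
    score = cal_text_score(texts, gt_texts, training_masks, runningScore(2))
    print(score['Mean IoU'])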