Spaces:
Sleeping
Sleeping
Add app
Browse files- app.py +136 -0
- models/__init__.py +15 -0
- models/loss.py +132 -0
- models/model.py +69 -0
- models/modules/__init__.py +6 -0
- models/modules/mobilenetv2.py +178 -0
- models/modules/resnet.py +300 -0
- models/modules/segmentation_head.py +206 -0
- models/modules/shufflenetv2.py +198 -0
- requirements.txt +7 -0
- utils/__init__.py +6 -0
- utils/cal_recall/__init__.py +4 -0
- utils/cal_recall/rrc_evaluation_funcs.py +394 -0
- utils/cal_recall/script.py +325 -0
- utils/make_trainfile.py +20 -0
- utils/metrics.py +53 -0
- utils/schedulers.py +64 -0
- utils/util.py +137 -0
app.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import pyclipper
|
11 |
+
import gradio as gr
|
12 |
+
|
13 |
+
from models import get_model
|
14 |
+
from utils.util import show_img, draw_bbox
|
15 |
+
|
16 |
+
class Pytorch_model:
|
17 |
+
def __init__(self, model_path, gpu_id=None):
|
18 |
+
'''
|
19 |
+
初始化pytorch模型
|
20 |
+
:param model_path: 模型地址(可以是模型的参数或者参数和计算图一起保存的文件)
|
21 |
+
:param gpu_id: 在哪一块gpu上运行
|
22 |
+
'''
|
23 |
+
self.gpu_id = gpu_id
|
24 |
+
|
25 |
+
if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
|
26 |
+
self.device = torch.device("cuda:%s" % self.gpu_id)
|
27 |
+
checkpoint = torch.load(model_path)
|
28 |
+
else:
|
29 |
+
self.device = torch.device("cpu")
|
30 |
+
checkpoint = torch.load(model_path, map_location='cpu')
|
31 |
+
print('device:', self.device)
|
32 |
+
|
33 |
+
config = checkpoint['config']
|
34 |
+
config['arch']['args']['pretrained'] = False
|
35 |
+
self.net = get_model(config)
|
36 |
+
|
37 |
+
self.img_channel = config['data_loader']['args']['dataset']['img_channel']
|
38 |
+
self.net.load_state_dict(checkpoint['state_dict']) ## load weights
|
39 |
+
self.net.to(self.device)
|
40 |
+
self.net.eval()
|
41 |
+
|
42 |
+
def predict(self, img, short_size: int = 736, min_area: int = 100):
|
43 |
+
'''
|
44 |
+
对传入的图像进行预测,支持图像地址, opencv读取图片,偏慢
|
45 |
+
:param img: the image as an np array
|
46 |
+
:param short_size:
|
47 |
+
:param min_area: 小于该尺度的bbox忽略
|
48 |
+
:return:
|
49 |
+
'''
|
50 |
+
img = cv2.imread(img)
|
51 |
+
if self.img_channel == 3:
|
52 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
53 |
+
h, w = img.shape[:2]
|
54 |
+
scale = short_size / min(h, w)
|
55 |
+
img = cv2.resize(img, None, fx=scale, fy=scale)
|
56 |
+
|
57 |
+
tensor = transforms.ToTensor()(img)
|
58 |
+
tensor = tensor.unsqueeze_(0)
|
59 |
+
|
60 |
+
tensor = tensor.to(self.device)
|
61 |
+
with torch.no_grad():
|
62 |
+
torch.cuda.synchronize(self.device)
|
63 |
+
start = time.time()
|
64 |
+
preds = self.net(tensor)[0]
|
65 |
+
torch.cuda.synchronize(self.device)
|
66 |
+
scale = (preds.shape[2] / w, preds.shape[1] / h)
|
67 |
+
t = time.time() - start
|
68 |
+
|
69 |
+
'''inference'''
|
70 |
+
start = time.time()
|
71 |
+
prob_map, thres_map = preds[0], preds[1]
|
72 |
+
|
73 |
+
## Step 1: Use threshold to get the binary map
|
74 |
+
thr = 0.2
|
75 |
+
out = (prob_map > thr).float() * 255
|
76 |
+
out = out.data.cpu().numpy().astype(np.uint8)
|
77 |
+
# cv2.imwrite('c_bin_map.png', out)
|
78 |
+
|
79 |
+
## Step 2: Connected components findContours
|
80 |
+
contours, hierarchy = cv2.findContours(out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
81 |
+
contours = [(i / scale).astype(int) for i in contours if len(i)>=4]
|
82 |
+
|
83 |
+
# Step 3: Dilate the shrunk region (not necessary)
|
84 |
+
ratio_prime = 1.5
|
85 |
+
dilated_polys = []
|
86 |
+
for poly in contours:
|
87 |
+
poly = poly[:,0,:]
|
88 |
+
D_prime = cv2.contourArea(poly) * ratio_prime / cv2.arcLength(poly, True) # formula(10) in the thesis
|
89 |
+
pco = pyclipper.PyclipperOffset()
|
90 |
+
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
91 |
+
dilated_poly = np.array(pco.Execute(D_prime))
|
92 |
+
if dilated_poly.size == 0 or dilated_poly.dtype != int or len(dilated_poly) != 1:
|
93 |
+
continue
|
94 |
+
dilated_polys.append(dilated_poly)
|
95 |
+
|
96 |
+
boxes_list = []
|
97 |
+
for cnt in dilated_polys:
|
98 |
+
# print('=============')
|
99 |
+
# print(cnt)
|
100 |
+
# print(len(cnt))
|
101 |
+
if cv2.contourArea(cnt) < min_area:
|
102 |
+
continue
|
103 |
+
rect = cv2.minAreaRect(cnt)
|
104 |
+
box = (cv2.boxPoints(rect)).astype(int)
|
105 |
+
boxes_list.append(box)
|
106 |
+
|
107 |
+
t = time.time() - start + t
|
108 |
+
|
109 |
+
boxes_list = np.array(boxes_list)
|
110 |
+
return dilated_polys, boxes_list, t
|
111 |
+
|
112 |
+
|
113 |
+
def run(img):
|
114 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
115 |
+
model_path = 'pre_trained/MobileNetv2_best_loss.pth'
|
116 |
+
|
117 |
+
# init model
|
118 |
+
model = Pytorch_model(model_path, gpu_id = None) ## set GPU id or None if you only have cpu
|
119 |
+
contours, boxes_list, t = model.predict(img)
|
120 |
+
print('Time: %.4f' %t)
|
121 |
+
|
122 |
+
img = img[:, :, ::-1]
|
123 |
+
imgc = img.copy()
|
124 |
+
cv2.drawContours(imgc, contours, -1, (22,222,22), 2, cv2.LINE_AA)
|
125 |
+
return imgc
|
126 |
+
# cv2.imwrite('contour.png', imgc)
|
127 |
+
# img = draw_bbox(img, boxes_list)
|
128 |
+
# cv2.imwrite('predict.jpg', img)
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
iface = gr.Interface(fn=run,
|
132 |
+
title="CJK Font Detection Using DBNet",
|
133 |
+
description="Gives an image containing CJK fonts, we will mark the texts out!",
|
134 |
+
inputs="image",
|
135 |
+
outputs="image")
|
136 |
+
iface.launch()
|
models/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .model import Model
|
4 |
+
from .loss import DBLoss
|
5 |
+
|
6 |
+
|
7 |
+
def get_model(config):
|
8 |
+
model_config = config['arch']['args']
|
9 |
+
return Model(model_config)
|
10 |
+
|
11 |
+
def get_loss(config):
|
12 |
+
alpha = config['loss']['args']['alpha']
|
13 |
+
beta = config['loss']['args']['beta']
|
14 |
+
ohem_ratio = config['loss']['args']['ohem_ratio']
|
15 |
+
return DBLoss(alpha=alpha, beta=beta, ohem_ratio=ohem_ratio)
|
models/loss.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import itertools
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
|
10 |
+
# import torchsnooper ## for debug
|
11 |
+
|
12 |
+
class DBLoss(nn.Module):
|
13 |
+
def __init__(self, alpha=1., beta=10., ohem_ratio=3):
|
14 |
+
"""
|
15 |
+
Implement DB Loss.
|
16 |
+
:param alpha: loss binary_map 前面的系数
|
17 |
+
:param beta: loss threshold 前面的系数
|
18 |
+
:param ohem_ratio: OHEM的比例
|
19 |
+
"""
|
20 |
+
super().__init__()
|
21 |
+
self.alpha = alpha
|
22 |
+
self.beta = beta
|
23 |
+
self.ohem_ratio = ohem_ratio
|
24 |
+
|
25 |
+
def forward(self, outputs, labels, training_masks, G_d):
|
26 |
+
"""
|
27 |
+
Implement DB Loss.
|
28 |
+
:param outputs: N 2 H W
|
29 |
+
:param labels: N 2 H W
|
30 |
+
:param training_masks:
|
31 |
+
"""
|
32 |
+
prob_map = outputs[:, 0, :, :]
|
33 |
+
thres_map = outputs[:, 1, :, :]
|
34 |
+
gt_prob = labels[:, 0, :, :]
|
35 |
+
gt_thres = labels[:, 1, :, :]
|
36 |
+
|
37 |
+
G_d = G_d.to(dtype = torch.float32)
|
38 |
+
training_masks = training_masks.to(dtype = torch.float32)
|
39 |
+
|
40 |
+
# OHEM mask (todo)
|
41 |
+
# selected_masks = self.ohem_batch(prob_map, gt_prob)
|
42 |
+
# selected_masks = selected_masks.to(outputs.device)
|
43 |
+
|
44 |
+
# 计算 prob loss
|
45 |
+
loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)
|
46 |
+
# loss_prob = self.bce_loss(prob_map, gt_prob, selected_masks)
|
47 |
+
|
48 |
+
# 计算 binary map loss
|
49 |
+
bin_map = self.DB(prob_map, thres_map)
|
50 |
+
loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)
|
51 |
+
# loss_prob = self.bce_loss(bin_map, gt_prob, selected_masks)
|
52 |
+
|
53 |
+
# 计算 threshold map loss
|
54 |
+
loss_fn = torch.nn.L1Loss(reduction='mean')
|
55 |
+
L1_loss = loss_fn(thres_map, gt_thres)
|
56 |
+
loss_thres = L1_loss * G_d
|
57 |
+
|
58 |
+
loss_prob = loss_prob.mean()
|
59 |
+
loss_bin = loss_bin.mean()
|
60 |
+
loss_thres = loss_thres.mean()
|
61 |
+
|
62 |
+
loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
|
63 |
+
return loss_all, loss_prob, loss_bin, loss_thres
|
64 |
+
|
65 |
+
def DB(self, prob_map, thres_map, k=50):
|
66 |
+
'''
|
67 |
+
Differentiable binarization
|
68 |
+
another form: torch.sigmoid(k * (prob_map - thres_map))
|
69 |
+
'''
|
70 |
+
return 1. / (torch.exp((-k * (prob_map - thres_map))) + 1)
|
71 |
+
|
72 |
+
def dice_loss(self, pred_cls, gt_cls, training_mask):
|
73 |
+
'''
|
74 |
+
dice loss
|
75 |
+
此处默认真实值和预测值的格式均为 NCHW
|
76 |
+
:param gt_cls:
|
77 |
+
:param pred_cls:
|
78 |
+
:param training_mask:
|
79 |
+
:return:
|
80 |
+
'''
|
81 |
+
eps = 1e-5
|
82 |
+
|
83 |
+
intersection = torch.sum(gt_cls * pred_cls * training_mask)
|
84 |
+
union = torch.sum(gt_cls * training_mask) + torch.sum(pred_cls * training_mask) + eps
|
85 |
+
loss = 1. - (2 * intersection / union)
|
86 |
+
|
87 |
+
return loss
|
88 |
+
|
89 |
+
def bce_loss(self, input, target, mask):
|
90 |
+
if mask.sum() == 0:
|
91 |
+
return torch.tensor(0.0, device=input.device, requires_grad=True)
|
92 |
+
target[target <= 0.5] = 0
|
93 |
+
target[target > 0.5] = 1
|
94 |
+
input = input[mask.bool()]
|
95 |
+
target = target[mask.bool()]
|
96 |
+
loss = nn.BCELoss(reduction='mean')(input, target)
|
97 |
+
return loss
|
98 |
+
|
99 |
+
def ohem_single(self, score, gt_text):
|
100 |
+
pos_num = (int)(np.sum(gt_text > 0.5))
|
101 |
+
|
102 |
+
if pos_num == 0:
|
103 |
+
selected_mask = np.zeros_like(score)
|
104 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
105 |
+
return selected_mask
|
106 |
+
|
107 |
+
neg_num = (int)(np.sum(gt_text <= 0.5))
|
108 |
+
neg_num = (int)(min(pos_num * self.ohem_ratio, neg_num))
|
109 |
+
|
110 |
+
if neg_num == 0:
|
111 |
+
selected_mask = np.zeros_like(score)
|
112 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
113 |
+
return selected_mask
|
114 |
+
|
115 |
+
neg_score = score[gt_text <= 0.5]
|
116 |
+
neg_score_sorted = np.sort(-neg_score)
|
117 |
+
threshold = -neg_score_sorted[neg_num - 1]
|
118 |
+
selected_mask = (score >= threshold) | (gt_text > 0.5)
|
119 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
120 |
+
return selected_mask
|
121 |
+
|
122 |
+
def ohem_batch(self, scores, gt_texts):
|
123 |
+
scores = scores.data.cpu().numpy()
|
124 |
+
gt_texts = gt_texts.data.cpu().numpy()
|
125 |
+
selected_masks = []
|
126 |
+
for i in range(scores.shape[0]):
|
127 |
+
selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :]))
|
128 |
+
|
129 |
+
selected_masks = np.concatenate(selected_masks, 0)
|
130 |
+
selected_masks = torch.from_numpy(selected_masks).float()
|
131 |
+
|
132 |
+
return selected_masks
|
models/model.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from models.modules import *
|
7 |
+
#from modules import *
|
8 |
+
|
9 |
+
backbone_dict = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]},
|
10 |
+
'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]},
|
11 |
+
'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]},
|
12 |
+
'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]},
|
13 |
+
'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]},
|
14 |
+
'resnext50_32x4d': {'models': resnext50_32x4d, 'out': [256, 512, 1024, 2048]},
|
15 |
+
'resnext101_32x8d': {'models': resnext101_32x8d, 'out': [256, 512, 1024, 2048]},
|
16 |
+
'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]},
|
17 |
+
'mobilenetv2': {'models': mobilenet_v2_x1_0, 'out': [24, 40, 160, 160]}
|
18 |
+
}
|
19 |
+
|
20 |
+
segmentation_head_dict = {'FPN': FPN, 'FPEM_FFM': FPEM_FFM}
|
21 |
+
|
22 |
+
# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
|
23 |
+
# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
|
24 |
+
# 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}
|
25 |
+
|
26 |
+
class Model(nn.Module):
|
27 |
+
def __init__(self, model_config: dict):
|
28 |
+
"""
|
29 |
+
PANnet
|
30 |
+
:param model_config: 模型配置
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
backbone = model_config['backbone']
|
34 |
+
pretrained = model_config['pretrained']
|
35 |
+
segmentation_head = model_config['segmentation_head']
|
36 |
+
|
37 |
+
assert backbone in backbone_dict, 'backbone must in: {}'.format(backbone_dict)
|
38 |
+
assert segmentation_head in segmentation_head_dict, 'segmentation_head must in: {}'.format(
|
39 |
+
segmentation_head_dict)
|
40 |
+
|
41 |
+
backbone_model, backbone_out = backbone_dict[backbone]['models'], backbone_dict[backbone]['out']
|
42 |
+
self.backbone = backbone_model(pretrained=pretrained)
|
43 |
+
self.segmentation_head = segmentation_head_dict[segmentation_head](backbone_out, **model_config)
|
44 |
+
self.name = '{}_{}'.format(backbone, segmentation_head)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
_, _, H, W = x.size()
|
48 |
+
backbone_out = self.backbone(x)
|
49 |
+
segmentation_head_out = self.segmentation_head(backbone_out)
|
50 |
+
y = segmentation_head_out
|
51 |
+
return y
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == '__main__':
|
55 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
56 |
+
print(device)
|
57 |
+
x = torch.zeros(1, 3, 640, 640).to(device)
|
58 |
+
|
59 |
+
model_config = {
|
60 |
+
'backbone': 'mobilenetv2',
|
61 |
+
'fpem_repeat': 2, # fpem模块重复的次数
|
62 |
+
'pretrained': False, # backbone 是否使用imagesnet的预训练模型
|
63 |
+
'segmentation_head': 'FPN' # 分割头,FPN or FPEM_FFM
|
64 |
+
}
|
65 |
+
model = Model(model_config=model_config).to(device)
|
66 |
+
y = model(x)
|
67 |
+
|
68 |
+
print(model)
|
69 |
+
#torch.save(model.state_dict(), 'PAN.pth')
|
models/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .resnet import *
|
4 |
+
from .shufflenetv2 import *
|
5 |
+
from .mobilenetv2 import *
|
6 |
+
from .segmentation_head import FPEM_FFM,FPN
|
models/modules/mobilenetv2.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.hub import load_state_dict_from_url
|
5 |
+
|
6 |
+
__all__ = ['mobilenet_v2_x1_0']
|
7 |
+
|
8 |
+
model_urls = {
|
9 |
+
#currently hadn't found a pretrained weight
|
10 |
+
'mobilenet_v2_x1_0': None,
|
11 |
+
}
|
12 |
+
def _make_divisible(v, divisor, min_value=None):
|
13 |
+
"""
|
14 |
+
This function is taken from the original tf repo.
|
15 |
+
It ensures that all layers have a channel number that is divisible by 8
|
16 |
+
It can be seen here:
|
17 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
18 |
+
:param v:
|
19 |
+
:param divisor:
|
20 |
+
:param min_value:
|
21 |
+
:return:
|
22 |
+
"""
|
23 |
+
if min_value is None:
|
24 |
+
min_value = divisor
|
25 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
26 |
+
# Make sure that round down does not go down by more than 10%.
|
27 |
+
if new_v < 0.9 * v:
|
28 |
+
new_v += divisor
|
29 |
+
return new_v
|
30 |
+
|
31 |
+
class InvertedResidual(nn.Module):
|
32 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
33 |
+
super(InvertedResidual, self).__init__()
|
34 |
+
|
35 |
+
if not (1 <= stride <= 2):
|
36 |
+
raise ValueError('illegal stride value')
|
37 |
+
self.stride = stride
|
38 |
+
|
39 |
+
self.exp_r = expand_ratio;
|
40 |
+
hidden_dim = round(inp * self.exp_r);
|
41 |
+
|
42 |
+
if self.exp_r == 1:
|
43 |
+
self.branch = nn.Sequential(
|
44 |
+
# dw conv
|
45 |
+
self.depthwise_conv(hidden_dim, hidden_dim, 3, stride, padding = 1, bias=False),
|
46 |
+
nn.BatchNorm2d(hidden_dim),
|
47 |
+
nn.ReLU6(inplace=True),
|
48 |
+
# pw-linear
|
49 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
50 |
+
nn.BatchNorm2d(oup),
|
51 |
+
)
|
52 |
+
|
53 |
+
else:
|
54 |
+
self.branch = nn.Sequential(
|
55 |
+
# pw
|
56 |
+
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
57 |
+
nn.BatchNorm2d(hidden_dim),
|
58 |
+
nn.ReLU6(inplace=True),
|
59 |
+
# dw
|
60 |
+
self.depthwise_conv(hidden_dim, hidden_dim, 3, stride, padding = 1, bias=False),
|
61 |
+
nn.BatchNorm2d(hidden_dim),
|
62 |
+
nn.ReLU6(inplace=True),
|
63 |
+
# pw-linear
|
64 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
65 |
+
nn.BatchNorm2d(oup),
|
66 |
+
)
|
67 |
+
|
68 |
+
self.downsample = nn.Sequential(
|
69 |
+
nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
|
70 |
+
nn.BatchNorm2d(oup),
|
71 |
+
)
|
72 |
+
|
73 |
+
self.identity = stride == 1
|
74 |
+
@staticmethod
|
75 |
+
def depthwise_conv(i, o, kernel_size, stride, padding, bias=False):
|
76 |
+
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
if self.identity:
|
80 |
+
downsampx = self.downsample(x)
|
81 |
+
return downsampx + self.branch(x)
|
82 |
+
else:
|
83 |
+
return self.branch(x)
|
84 |
+
|
85 |
+
|
86 |
+
class MobileNetV2(nn.Module):
|
87 |
+
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, width_mult=1.):
|
88 |
+
super(MobileNetV2, self).__init__()
|
89 |
+
|
90 |
+
if len(stages_repeats) != 5:
|
91 |
+
raise ValueError('expected stages_repeats as list of 4 positive ints')
|
92 |
+
if len(stages_out_channels) != 6:
|
93 |
+
raise ValueError('expected stages_out_channels as list of 5 positive ints')
|
94 |
+
self._stage_out_channels = stages_out_channels
|
95 |
+
|
96 |
+
self.tlist = [1,6,6,6,6]
|
97 |
+
self.slist = [1,2,2,2,1]
|
98 |
+
|
99 |
+
input_channels = 3
|
100 |
+
output_channels = self._stage_out_channels[0] # 32
|
101 |
+
#output_channels = _make_divisible(output_channels * width_mult, 4 if width_mult == 0.1 else 8)
|
102 |
+
|
103 |
+
self.conv1 = nn.Sequential(
|
104 |
+
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
|
105 |
+
nn.BatchNorm2d(output_channels),
|
106 |
+
nn.ReLU6(inplace=True),
|
107 |
+
)
|
108 |
+
input_channels = output_channels
|
109 |
+
|
110 |
+
|
111 |
+
stage_names = ['stage{}'.format(i) for i in [2, 3, 4, 5, 6]]
|
112 |
+
for name, repeats, output_channels, t, s in zip(
|
113 |
+
stage_names, stages_repeats, self._stage_out_channels[1:], self.tlist, self.slist):
|
114 |
+
|
115 |
+
#output_channels = _make_divisible(output_channels * width_mult, 4 if width_mult == 0.1 else 8)
|
116 |
+
|
117 |
+
seq = [InvertedResidual(input_channels, output_channels, s, t)]
|
118 |
+
for i in range(repeats - 1):
|
119 |
+
|
120 |
+
seq.append(InvertedResidual(output_channels, output_channels, 1, t))
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
setattr(self, name, nn.Sequential(*seq))
|
125 |
+
input_channels = output_channels
|
126 |
+
|
127 |
+
|
128 |
+
output_channels = self._stage_out_channels[-1]
|
129 |
+
|
130 |
+
self.conv9 = nn.Sequential(
|
131 |
+
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
|
132 |
+
nn.BatchNorm2d(output_channels),
|
133 |
+
nn.ReLU6(inplace=True),
|
134 |
+
)
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
x = self.conv1(x)
|
140 |
+
c2 = self.stage2(x)
|
141 |
+
c3 = self.stage3(c2)
|
142 |
+
c4 = self.stage4(c3)
|
143 |
+
c5 = self.stage5(c4)
|
144 |
+
c6 = self.stage6(c5)
|
145 |
+
#c7 = self.stage7(c6)
|
146 |
+
#c8 = self.stage8(c7)
|
147 |
+
#c9 = self.conv9(c8)
|
148 |
+
|
149 |
+
return c3, c4, c5, c6
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
def _mobilenetv2(arch, pretrained, progress, *args, **kwargs):
|
154 |
+
model = MobileNetV2(*args, **kwargs)
|
155 |
+
|
156 |
+
if pretrained:
|
157 |
+
model_url = model_urls[arch]
|
158 |
+
if model_url is None:
|
159 |
+
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
|
160 |
+
else:
|
161 |
+
state_dict = load_state_dict_from_url(model_url, progress=progress)
|
162 |
+
model.load_state_dict(state_dict,strict=False)
|
163 |
+
|
164 |
+
return model
|
165 |
+
|
166 |
+
|
167 |
+
def mobilenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
|
168 |
+
"""
|
169 |
+
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
|
170 |
+
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
|
171 |
+
<https://arxiv.org/abs/1807.11164>`_.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
175 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
176 |
+
"""
|
177 |
+
return _mobilenetv2('mobilenet_v2_x1_0', pretrained, progress,
|
178 |
+
[1, 2, 3, 4, 3], [32, 16, 24, 40, 160, 160], **kwargs)
|
models/modules/resnet.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.hub import load_state_dict_from_url
|
5 |
+
#from torchvision.models.utils import load_state_dict_from_url
|
6 |
+
|
7 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
8 |
+
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
|
9 |
+
|
10 |
+
model_urls = {
|
11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
16 |
+
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
17 |
+
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
22 |
+
"""3x3 convolution with padding"""
|
23 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
24 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
25 |
+
|
26 |
+
|
27 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
28 |
+
"""1x1 convolution"""
|
29 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
30 |
+
|
31 |
+
|
32 |
+
class BasicBlock(nn.Module):
|
33 |
+
expansion = 1
|
34 |
+
|
35 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
36 |
+
base_width=64, dilation=1, norm_layer=None):
|
37 |
+
super(BasicBlock, self).__init__()
|
38 |
+
if norm_layer is None:
|
39 |
+
norm_layer = nn.BatchNorm2d
|
40 |
+
if groups != 1 or base_width != 64:
|
41 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
42 |
+
if dilation > 1:
|
43 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
44 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
45 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
46 |
+
self.bn1 = norm_layer(planes)
|
47 |
+
self.relu = nn.ReLU(inplace=True)
|
48 |
+
self.conv2 = conv3x3(planes, planes)
|
49 |
+
self.bn2 = norm_layer(planes)
|
50 |
+
self.downsample = downsample
|
51 |
+
self.stride = stride
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
identity = x
|
55 |
+
|
56 |
+
out = self.conv1(x)
|
57 |
+
out = self.bn1(out)
|
58 |
+
out = self.relu(out)
|
59 |
+
|
60 |
+
out = self.conv2(out)
|
61 |
+
out = self.bn2(out)
|
62 |
+
|
63 |
+
if self.downsample is not None:
|
64 |
+
identity = self.downsample(x)
|
65 |
+
|
66 |
+
out += identity
|
67 |
+
out = self.relu(out)
|
68 |
+
|
69 |
+
return out
|
70 |
+
|
71 |
+
|
72 |
+
class Bottleneck(nn.Module):
|
73 |
+
expansion = 4
|
74 |
+
|
75 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
76 |
+
base_width=64, dilation=1, norm_layer=None):
|
77 |
+
super(Bottleneck, self).__init__()
|
78 |
+
if norm_layer is None:
|
79 |
+
norm_layer = nn.BatchNorm2d
|
80 |
+
width = int(planes * (base_width / 64.)) * groups
|
81 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
82 |
+
self.conv1 = conv1x1(inplanes, width)
|
83 |
+
self.bn1 = norm_layer(width)
|
84 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
85 |
+
self.bn2 = norm_layer(width)
|
86 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
87 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
88 |
+
self.relu = nn.ReLU(inplace=True)
|
89 |
+
self.downsample = downsample
|
90 |
+
self.stride = stride
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
identity = x
|
94 |
+
|
95 |
+
out = self.conv1(x)
|
96 |
+
out = self.bn1(out)
|
97 |
+
out = self.relu(out)
|
98 |
+
|
99 |
+
out = self.conv2(out)
|
100 |
+
out = self.bn2(out)
|
101 |
+
out = self.relu(out)
|
102 |
+
|
103 |
+
out = self.conv3(out)
|
104 |
+
out = self.bn3(out)
|
105 |
+
|
106 |
+
if self.downsample is not None:
|
107 |
+
identity = self.downsample(x)
|
108 |
+
|
109 |
+
out += identity
|
110 |
+
out = self.relu(out)
|
111 |
+
|
112 |
+
return out
|
113 |
+
|
114 |
+
|
115 |
+
class ResNet(nn.Module):
|
116 |
+
|
117 |
+
def __init__(self, block, layers, zero_init_residual=False,
|
118 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
119 |
+
norm_layer=None):
|
120 |
+
super(ResNet, self).__init__()
|
121 |
+
if norm_layer is None:
|
122 |
+
norm_layer = nn.BatchNorm2d
|
123 |
+
self._norm_layer = norm_layer
|
124 |
+
|
125 |
+
self.inplanes = 64
|
126 |
+
self.dilation = 1
|
127 |
+
if replace_stride_with_dilation is None:
|
128 |
+
# each element in the tuple indicates if we should replace
|
129 |
+
# the 2x2 stride with a dilated convolution instead
|
130 |
+
replace_stride_with_dilation = [False, False, False]
|
131 |
+
if len(replace_stride_with_dilation) != 3:
|
132 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
133 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
134 |
+
self.groups = groups
|
135 |
+
self.base_width = width_per_group
|
136 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
137 |
+
bias=False)
|
138 |
+
self.bn1 = norm_layer(self.inplanes)
|
139 |
+
self.relu = nn.ReLU(inplace=True)
|
140 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
141 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
142 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
143 |
+
dilate=replace_stride_with_dilation[0])
|
144 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
145 |
+
dilate=replace_stride_with_dilation[1])
|
146 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
147 |
+
dilate=replace_stride_with_dilation[2])
|
148 |
+
|
149 |
+
for m in self.modules():
|
150 |
+
if isinstance(m, nn.Conv2d):
|
151 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
152 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
153 |
+
nn.init.constant_(m.weight, 1)
|
154 |
+
nn.init.constant_(m.bias, 0)
|
155 |
+
|
156 |
+
# Zero-initialize the last BN in each residual branch,
|
157 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
158 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
159 |
+
if zero_init_residual:
|
160 |
+
for m in self.modules():
|
161 |
+
if isinstance(m, Bottleneck):
|
162 |
+
nn.init.constant_(m.bn3.weight, 0)
|
163 |
+
elif isinstance(m, BasicBlock):
|
164 |
+
nn.init.constant_(m.bn2.weight, 0)
|
165 |
+
|
166 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
167 |
+
norm_layer = self._norm_layer
|
168 |
+
downsample = None
|
169 |
+
previous_dilation = self.dilation
|
170 |
+
if dilate:
|
171 |
+
self.dilation *= stride
|
172 |
+
stride = 1
|
173 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
174 |
+
downsample = nn.Sequential(
|
175 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
176 |
+
norm_layer(planes * block.expansion),
|
177 |
+
)
|
178 |
+
|
179 |
+
layers = []
|
180 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
181 |
+
self.base_width, previous_dilation, norm_layer))
|
182 |
+
self.inplanes = planes * block.expansion
|
183 |
+
for _ in range(1, blocks):
|
184 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
185 |
+
base_width=self.base_width, dilation=self.dilation,
|
186 |
+
norm_layer=norm_layer))
|
187 |
+
|
188 |
+
return nn.Sequential(*layers)
|
189 |
+
|
190 |
+
def forward(self, x):
|
191 |
+
x = self.conv1(x)
|
192 |
+
x = self.bn1(x)
|
193 |
+
x = self.relu(x)
|
194 |
+
x = self.maxpool(x)
|
195 |
+
|
196 |
+
c2 = self.layer1(x)
|
197 |
+
c3 = self.layer2(c2)
|
198 |
+
c4 = self.layer3(c3)
|
199 |
+
c5 = self.layer4(c4)
|
200 |
+
|
201 |
+
return c2, c3, c4, c5
|
202 |
+
|
203 |
+
|
204 |
+
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
205 |
+
model = ResNet(block, layers, **kwargs)
|
206 |
+
if pretrained:
|
207 |
+
state_dict = load_state_dict_from_url(model_urls[arch],
|
208 |
+
progress=progress)
|
209 |
+
model.load_state_dict(state_dict, strict=False)
|
210 |
+
print('load pretrained models from imagenet')
|
211 |
+
return model
|
212 |
+
|
213 |
+
|
214 |
+
def resnet18(pretrained=False, progress=True, **kwargs):
|
215 |
+
"""Constructs a ResNet-18 model.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
219 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
220 |
+
"""
|
221 |
+
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
222 |
+
**kwargs)
|
223 |
+
|
224 |
+
|
225 |
+
def resnet34(pretrained=False, progress=True, **kwargs):
|
226 |
+
"""Constructs a ResNet-34 model.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
230 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
231 |
+
"""
|
232 |
+
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
233 |
+
**kwargs)
|
234 |
+
|
235 |
+
|
236 |
+
def resnet50(pretrained=False, progress=True, **kwargs):
|
237 |
+
"""Constructs a ResNet-50 model.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
241 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
242 |
+
"""
|
243 |
+
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
244 |
+
**kwargs)
|
245 |
+
|
246 |
+
|
247 |
+
def resnet101(pretrained=False, progress=True, **kwargs):
|
248 |
+
"""Constructs a ResNet-101 model.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
252 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
253 |
+
"""
|
254 |
+
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
255 |
+
**kwargs)
|
256 |
+
|
257 |
+
|
258 |
+
def resnet152(pretrained=False, progress=True, **kwargs):
|
259 |
+
"""Constructs a ResNet-152 model.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
263 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
264 |
+
"""
|
265 |
+
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
266 |
+
**kwargs)
|
267 |
+
|
268 |
+
|
269 |
+
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
270 |
+
"""Constructs a ResNeXt-50 32x4d model.
|
271 |
+
|
272 |
+
Args:
|
273 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
274 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
275 |
+
"""
|
276 |
+
kwargs['groups'] = 32
|
277 |
+
kwargs['width_per_group'] = 4
|
278 |
+
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
279 |
+
pretrained, progress, **kwargs)
|
280 |
+
|
281 |
+
|
282 |
+
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
283 |
+
"""Constructs a ResNeXt-101 32x8d model.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
287 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
288 |
+
"""
|
289 |
+
kwargs['groups'] = 32
|
290 |
+
kwargs['width_per_group'] = 8
|
291 |
+
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
292 |
+
pretrained, progress, **kwargs)
|
293 |
+
|
294 |
+
if __name__ == '__main__':
|
295 |
+
import torch
|
296 |
+
x = torch.zeros(1, 3, 640, 640)
|
297 |
+
net = resnext101_32x8d(pretrained=False)
|
298 |
+
y = net(x)
|
299 |
+
for u in y:
|
300 |
+
print(u.shape)
|
models/modules/segmentation_head.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class FPN(nn.Module):
|
9 |
+
def __init__(self, backbone_out_channels, **kwargs):
|
10 |
+
"""
|
11 |
+
:param backbone_out_channels: 基础网络输出的维度
|
12 |
+
:param kwargs:
|
13 |
+
"""
|
14 |
+
super().__init__()
|
15 |
+
# result_num = kwargs.get('result_num', 6)
|
16 |
+
inplace = True
|
17 |
+
conv_out = 256
|
18 |
+
# reduce layers
|
19 |
+
self.reduce_conv_c2 = nn.Sequential(
|
20 |
+
nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0),
|
21 |
+
nn.BatchNorm2d(conv_out),
|
22 |
+
nn.ReLU(inplace=inplace)
|
23 |
+
)
|
24 |
+
self.reduce_conv_c3 = nn.Sequential(
|
25 |
+
nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0),
|
26 |
+
nn.BatchNorm2d(conv_out),
|
27 |
+
nn.ReLU(inplace=inplace)
|
28 |
+
)
|
29 |
+
self.reduce_conv_c4 = nn.Sequential(
|
30 |
+
nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0),
|
31 |
+
nn.BatchNorm2d(conv_out),
|
32 |
+
nn.ReLU(inplace=inplace)
|
33 |
+
)
|
34 |
+
|
35 |
+
self.reduce_conv_c5 = nn.Sequential(
|
36 |
+
nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0),
|
37 |
+
nn.BatchNorm2d(conv_out),
|
38 |
+
nn.ReLU(inplace=inplace)
|
39 |
+
)
|
40 |
+
# Smooth layers
|
41 |
+
self.smooth_p4 = nn.Sequential(
|
42 |
+
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
|
43 |
+
nn.BatchNorm2d(conv_out),
|
44 |
+
nn.ReLU(inplace=inplace)
|
45 |
+
)
|
46 |
+
self.smooth_p3 = nn.Sequential(
|
47 |
+
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
|
48 |
+
nn.BatchNorm2d(conv_out),
|
49 |
+
nn.ReLU(inplace=inplace)
|
50 |
+
)
|
51 |
+
self.smooth_p2 = nn.Sequential(
|
52 |
+
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
|
53 |
+
nn.BatchNorm2d(conv_out),
|
54 |
+
nn.ReLU(inplace=inplace)
|
55 |
+
)
|
56 |
+
|
57 |
+
self.conv = nn.Sequential(
|
58 |
+
nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
|
59 |
+
nn.BatchNorm2d(conv_out),
|
60 |
+
nn.ReLU(inplace=inplace)
|
61 |
+
)
|
62 |
+
# self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)
|
63 |
+
|
64 |
+
self.pred_conv = nn.Sequential(
|
65 |
+
nn.Conv2d(conv_out, 2, kernel_size=1, stride=1, padding=0),
|
66 |
+
nn.Sigmoid()
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
c2, c3, c4, c5 = x
|
71 |
+
# Top-down
|
72 |
+
p5 = self.reduce_conv_c5(c5)
|
73 |
+
p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
|
74 |
+
p4 = self.smooth_p4(p4)
|
75 |
+
p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
|
76 |
+
p3 = self.smooth_p3(p3)
|
77 |
+
p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
|
78 |
+
p2 = self.smooth_p2(p2)
|
79 |
+
|
80 |
+
x = self._upsample_cat(p2, p3, p4, p5)
|
81 |
+
x = self.conv(x)
|
82 |
+
|
83 |
+
# x = self.out_conv(x)
|
84 |
+
|
85 |
+
x = self.pred_conv(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
def _upsample_add(self, x, y):
|
89 |
+
return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=True) + y
|
90 |
+
|
91 |
+
def _upsample_cat(self, p2, p3, p4, p5):
|
92 |
+
h, w = p2.size()[2:]
|
93 |
+
p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=True)
|
94 |
+
p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=True)
|
95 |
+
p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=True)
|
96 |
+
return torch.cat([p2, p3, p4, p5], dim=1)
|
97 |
+
|
98 |
+
|
99 |
+
class FPEM_FFM(nn.Module):
|
100 |
+
def __init__(self, backbone_out_channels, **kwargs):
|
101 |
+
"""
|
102 |
+
PANnet
|
103 |
+
:param backbone_out_channels: 基础网络输出的维度
|
104 |
+
"""
|
105 |
+
super().__init__()
|
106 |
+
fpem_repeat = kwargs.get('fpem_repeat', 2)
|
107 |
+
conv_out = 128
|
108 |
+
# reduce layers
|
109 |
+
self.reduce_conv_c2 = nn.Sequential(
|
110 |
+
nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=conv_out, kernel_size=1),
|
111 |
+
nn.BatchNorm2d(conv_out),
|
112 |
+
nn.ReLU()
|
113 |
+
)
|
114 |
+
self.reduce_conv_c3 = nn.Sequential(
|
115 |
+
nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=conv_out, kernel_size=1),
|
116 |
+
nn.BatchNorm2d(conv_out),
|
117 |
+
nn.ReLU()
|
118 |
+
)
|
119 |
+
self.reduce_conv_c4 = nn.Sequential(
|
120 |
+
nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=conv_out, kernel_size=1),
|
121 |
+
nn.BatchNorm2d(conv_out),
|
122 |
+
nn.ReLU()
|
123 |
+
)
|
124 |
+
self.reduce_conv_c5 = nn.Sequential(
|
125 |
+
nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=conv_out, kernel_size=1),
|
126 |
+
nn.BatchNorm2d(conv_out),
|
127 |
+
nn.ReLU()
|
128 |
+
)
|
129 |
+
self.fpems = nn.ModuleList()
|
130 |
+
for i in range(fpem_repeat):
|
131 |
+
self.fpems.append(FPEM(conv_out))
|
132 |
+
self.out_conv = nn.Conv2d(in_channels=conv_out * 4, out_channels=6, kernel_size=1)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
c2, c3, c4, c5 = x
|
136 |
+
# reduce channel
|
137 |
+
c2 = self.reduce_conv_c2(c2)
|
138 |
+
c3 = self.reduce_conv_c3(c3)
|
139 |
+
c4 = self.reduce_conv_c4(c4)
|
140 |
+
c5 = self.reduce_conv_c5(c5)
|
141 |
+
|
142 |
+
# FPEM
|
143 |
+
for i, fpem in enumerate(self.fpems):
|
144 |
+
c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
|
145 |
+
if i == 0:
|
146 |
+
c2_ffm = c2
|
147 |
+
c3_ffm = c3
|
148 |
+
c4_ffm = c4
|
149 |
+
c5_ffm = c5
|
150 |
+
else:
|
151 |
+
c2_ffm += c2
|
152 |
+
c3_ffm += c3
|
153 |
+
c4_ffm += c4
|
154 |
+
c5_ffm += c5
|
155 |
+
|
156 |
+
# FFM
|
157 |
+
c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:], mode='bilinear')
|
158 |
+
c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:], mode='bilinear')
|
159 |
+
c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:], mode='bilinear')
|
160 |
+
Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
|
161 |
+
y = self.out_conv(Fy)
|
162 |
+
return y
|
163 |
+
|
164 |
+
|
165 |
+
class FPEM(nn.Module):
|
166 |
+
def __init__(self, in_channels=128):
|
167 |
+
super().__init__()
|
168 |
+
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
|
169 |
+
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
|
170 |
+
self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
|
171 |
+
self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
|
172 |
+
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
|
173 |
+
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
|
174 |
+
|
175 |
+
def forward(self, c2, c3, c4, c5):
|
176 |
+
# up阶段
|
177 |
+
c4 = self.up_add1(self._upsample_add(c5, c4))
|
178 |
+
c3 = self.up_add2(self._upsample_add(c4, c3))
|
179 |
+
c2 = self.up_add3(self._upsample_add(c3, c2))
|
180 |
+
|
181 |
+
# down 阶段
|
182 |
+
c3 = self.down_add1(self._upsample_add(c3, c2))
|
183 |
+
c4 = self.down_add2(self._upsample_add(c4, c3))
|
184 |
+
c5 = self.down_add3(self._upsample_add(c5, c4))
|
185 |
+
return c2, c3, c4, c5
|
186 |
+
|
187 |
+
def _upsample_add(self, x, y):
|
188 |
+
return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
|
189 |
+
|
190 |
+
|
191 |
+
class SeparableConv2d(nn.Module):
|
192 |
+
def __init__(self, in_channels, out_channels, stride=1):
|
193 |
+
super(SeparableConv2d, self).__init__()
|
194 |
+
|
195 |
+
self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
|
196 |
+
stride=stride, groups=in_channels)
|
197 |
+
self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
|
198 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
199 |
+
self.relu = nn.ReLU()
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
x = self.depthwise_conv(x)
|
203 |
+
x = self.pointwise_conv(x)
|
204 |
+
x = self.bn(x)
|
205 |
+
x = self.relu(x)
|
206 |
+
return x
|
models/modules/shufflenetv2.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.hub import load_state_dict_from_url
|
6 |
+
#from torchvision.models.utils import load_state_dict_from_url
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
|
10 |
+
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
|
11 |
+
]
|
12 |
+
|
13 |
+
model_urls = {
|
14 |
+
'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
|
15 |
+
'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
|
16 |
+
'shufflenetv2_x1.5': None,
|
17 |
+
'shufflenetv2_x2.0': None,
|
18 |
+
}
|
19 |
+
|
20 |
+
def channel_shuffle(x, groups):
|
21 |
+
batchsize, num_channels, height, width = x.data.size()
|
22 |
+
channels_per_group = num_channels // groups
|
23 |
+
|
24 |
+
# reshape
|
25 |
+
x = x.view(batchsize, groups,
|
26 |
+
channels_per_group, height, width)
|
27 |
+
|
28 |
+
x = torch.transpose(x, 1, 2).contiguous()
|
29 |
+
|
30 |
+
# flatten
|
31 |
+
x = x.view(batchsize, -1, height, width)
|
32 |
+
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
class InvertedResidual(nn.Module):
|
37 |
+
def __init__(self, inp, oup, stride):
|
38 |
+
super(InvertedResidual, self).__init__()
|
39 |
+
|
40 |
+
if not (1 <= stride <= 3):
|
41 |
+
raise ValueError('illegal stride value')
|
42 |
+
self.stride = stride
|
43 |
+
|
44 |
+
branch_features = oup // 2
|
45 |
+
assert (self.stride != 1) or (inp == branch_features << 1)
|
46 |
+
|
47 |
+
if self.stride > 1:
|
48 |
+
self.branch1 = nn.Sequential(
|
49 |
+
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
|
50 |
+
nn.BatchNorm2d(inp),
|
51 |
+
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
|
52 |
+
nn.BatchNorm2d(branch_features),
|
53 |
+
nn.ReLU(inplace=True),
|
54 |
+
)
|
55 |
+
|
56 |
+
self.branch2 = nn.Sequential(
|
57 |
+
nn.Conv2d(inp if (self.stride > 1) else branch_features,
|
58 |
+
branch_features, kernel_size=1, stride=1, padding=0, bias=False),
|
59 |
+
nn.BatchNorm2d(branch_features),
|
60 |
+
nn.ReLU(inplace=True),
|
61 |
+
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
|
62 |
+
nn.BatchNorm2d(branch_features),
|
63 |
+
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
|
64 |
+
nn.BatchNorm2d(branch_features),
|
65 |
+
nn.ReLU(inplace=True),
|
66 |
+
)
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
|
70 |
+
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
if self.stride == 1:
|
74 |
+
x1, x2 = x.chunk(2, dim=1)
|
75 |
+
out = torch.cat((x1, self.branch2(x2)), dim=1)
|
76 |
+
else:
|
77 |
+
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
|
78 |
+
|
79 |
+
out = channel_shuffle(out, 2)
|
80 |
+
|
81 |
+
return out
|
82 |
+
|
83 |
+
|
84 |
+
class ShuffleNetV2(nn.Module):
|
85 |
+
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
|
86 |
+
super(ShuffleNetV2, self).__init__()
|
87 |
+
|
88 |
+
if len(stages_repeats) != 3:
|
89 |
+
raise ValueError('expected stages_repeats as list of 3 positive ints')
|
90 |
+
if len(stages_out_channels) != 5:
|
91 |
+
raise ValueError('expected stages_out_channels as list of 5 positive ints')
|
92 |
+
self._stage_out_channels = stages_out_channels
|
93 |
+
|
94 |
+
input_channels = 3
|
95 |
+
output_channels = self._stage_out_channels[0]
|
96 |
+
self.conv1 = nn.Sequential(
|
97 |
+
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
|
98 |
+
nn.BatchNorm2d(output_channels),
|
99 |
+
nn.ReLU(inplace=True),
|
100 |
+
)
|
101 |
+
input_channels = output_channels
|
102 |
+
|
103 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
104 |
+
|
105 |
+
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
|
106 |
+
for name, repeats, output_channels in zip(
|
107 |
+
stage_names, stages_repeats, self._stage_out_channels[1:]):
|
108 |
+
seq = [InvertedResidual(input_channels, output_channels, 2)]
|
109 |
+
for i in range(repeats - 1):
|
110 |
+
seq.append(InvertedResidual(output_channels, output_channels, 1))
|
111 |
+
setattr(self, name, nn.Sequential(*seq))
|
112 |
+
input_channels = output_channels
|
113 |
+
|
114 |
+
output_channels = self._stage_out_channels[-1]
|
115 |
+
self.conv5 = nn.Sequential(
|
116 |
+
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
|
117 |
+
nn.BatchNorm2d(output_channels),
|
118 |
+
nn.ReLU(inplace=True),
|
119 |
+
)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = self.conv1(x)
|
123 |
+
c2 = self.maxpool(x)
|
124 |
+
c3 = self.stage2(c2)
|
125 |
+
c4 = self.stage3(c3)
|
126 |
+
c5 = self.stage4(c4)
|
127 |
+
# c5 = self.conv5(c5)
|
128 |
+
return c2, c3, c4, c5
|
129 |
+
|
130 |
+
|
131 |
+
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
|
132 |
+
model = ShuffleNetV2(*args, **kwargs)
|
133 |
+
|
134 |
+
if pretrained:
|
135 |
+
model_url = model_urls[arch]
|
136 |
+
if model_url is None:
|
137 |
+
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
|
138 |
+
else:
|
139 |
+
state_dict = load_state_dict_from_url(model_url, progress=progress)
|
140 |
+
model.load_state_dict(state_dict,strict=False)
|
141 |
+
|
142 |
+
return model
|
143 |
+
|
144 |
+
|
145 |
+
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
|
146 |
+
"""
|
147 |
+
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
|
148 |
+
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
|
149 |
+
<https://arxiv.org/abs/1807.11164>`_.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
153 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
154 |
+
"""
|
155 |
+
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
|
156 |
+
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
|
157 |
+
|
158 |
+
|
159 |
+
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
|
160 |
+
"""
|
161 |
+
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
|
162 |
+
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
|
163 |
+
<https://arxiv.org/abs/1807.11164>`_.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
167 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
168 |
+
"""
|
169 |
+
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
|
170 |
+
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
|
171 |
+
|
172 |
+
|
173 |
+
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
|
174 |
+
"""
|
175 |
+
Constructs a ShuffleNetV2 with 1.5x output channels, as described in
|
176 |
+
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
|
177 |
+
<https://arxiv.org/abs/1807.11164>`_.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
181 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
182 |
+
"""
|
183 |
+
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
|
184 |
+
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
|
185 |
+
|
186 |
+
|
187 |
+
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
|
188 |
+
"""
|
189 |
+
Constructs a ShuffleNetV2 with 2.0x output channels, as described in
|
190 |
+
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
|
191 |
+
<https://arxiv.org/abs/1807.11164>`_.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
195 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
196 |
+
"""
|
197 |
+
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
|
198 |
+
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
opencv-python
|
4 |
+
numpy
|
5 |
+
pyclipper
|
6 |
+
gradio
|
7 |
+
matplotlib
|
utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .util import *
|
4 |
+
from .metrics import *
|
5 |
+
from .schedulers import *
|
6 |
+
from .cal_recall.script import cal_recall_precison_f1
|
utils/cal_recall/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .script import cal_recall_precison_f1
|
4 |
+
__all__ = ['cal_recall_precison_f1']
|
utils/cal_recall/rrc_evaluation_funcs.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
#encoding: UTF-8
|
3 |
+
import json
|
4 |
+
import sys;sys.path.append('./')
|
5 |
+
import zipfile
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import codecs
|
10 |
+
import traceback
|
11 |
+
import numpy as np
|
12 |
+
from utils import order_points_clockwise
|
13 |
+
|
14 |
+
def print_help():
|
15 |
+
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
|
16 |
+
sys.exit(2)
|
17 |
+
|
18 |
+
|
19 |
+
def load_zip_file_keys(file,fileNameRegExp=''):
|
20 |
+
"""
|
21 |
+
Returns an array with the entries of the ZIP file that match with the regular expression.
|
22 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
23 |
+
"""
|
24 |
+
try:
|
25 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
26 |
+
except :
|
27 |
+
raise Exception('Error loading the ZIP archive.')
|
28 |
+
|
29 |
+
pairs = []
|
30 |
+
|
31 |
+
for name in archive.namelist():
|
32 |
+
addFile = True
|
33 |
+
keyName = name
|
34 |
+
if fileNameRegExp!="":
|
35 |
+
m = re.match(fileNameRegExp,name)
|
36 |
+
if m == None:
|
37 |
+
addFile = False
|
38 |
+
else:
|
39 |
+
if len(m.groups())>0:
|
40 |
+
keyName = m.group(1)
|
41 |
+
|
42 |
+
if addFile:
|
43 |
+
pairs.append( keyName )
|
44 |
+
|
45 |
+
return pairs
|
46 |
+
|
47 |
+
|
48 |
+
def load_zip_file(file,fileNameRegExp='',allEntries=False):
|
49 |
+
"""
|
50 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
51 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
52 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
53 |
+
"""
|
54 |
+
try:
|
55 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
56 |
+
except :
|
57 |
+
raise Exception('Error loading the ZIP archive')
|
58 |
+
|
59 |
+
pairs = []
|
60 |
+
for name in archive.namelist():
|
61 |
+
addFile = True
|
62 |
+
keyName = name
|
63 |
+
if fileNameRegExp!="":
|
64 |
+
m = re.match(fileNameRegExp,name)
|
65 |
+
if m == None:
|
66 |
+
addFile = False
|
67 |
+
else:
|
68 |
+
if len(m.groups())>0:
|
69 |
+
keyName = m.group(1)
|
70 |
+
|
71 |
+
if addFile:
|
72 |
+
pairs.append( [ keyName , archive.read(name)] )
|
73 |
+
else:
|
74 |
+
if allEntries:
|
75 |
+
raise Exception('ZIP entry not valid: %s' %name)
|
76 |
+
|
77 |
+
return dict(pairs)
|
78 |
+
|
79 |
+
|
80 |
+
def load_folder_file(file, fileNameRegExp='', allEntries=False):
|
81 |
+
"""
|
82 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
83 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
84 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
85 |
+
"""
|
86 |
+
pairs = []
|
87 |
+
for name in os.listdir(file):
|
88 |
+
addFile = True
|
89 |
+
keyName = name
|
90 |
+
print(name)
|
91 |
+
print(fileNameRegExp)
|
92 |
+
if fileNameRegExp != "":
|
93 |
+
m = re.match(fileNameRegExp, name)
|
94 |
+
print(m)
|
95 |
+
if m == None:
|
96 |
+
addFile = False
|
97 |
+
else:
|
98 |
+
if len(m.groups()) > 0:
|
99 |
+
keyName = m.group(1)
|
100 |
+
|
101 |
+
if addFile:
|
102 |
+
pairs.append([keyName, open(os.path.join(file,name),encoding="utf-8").read()])
|
103 |
+
else:
|
104 |
+
if allEntries:
|
105 |
+
raise Exception('ZIP entry not valid: %s' % name)
|
106 |
+
|
107 |
+
return dict(pairs)
|
108 |
+
|
109 |
+
|
110 |
+
def decode_utf8(raw):
|
111 |
+
"""
|
112 |
+
Returns a Unicode object on success, or None on failure
|
113 |
+
"""
|
114 |
+
try:
|
115 |
+
raw = codecs.decode(raw,'utf-8', 'replace')
|
116 |
+
#extracts BOM if exists
|
117 |
+
raw = raw.encode('utf8')
|
118 |
+
if raw.startswith(codecs.BOM_UTF8):
|
119 |
+
raw = raw.replace(codecs.BOM_UTF8, '', 1)
|
120 |
+
return raw.decode('utf-8')
|
121 |
+
except:
|
122 |
+
return None
|
123 |
+
|
124 |
+
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
125 |
+
"""
|
126 |
+
This function validates that all lines of the file calling the Line validation function for each line
|
127 |
+
"""
|
128 |
+
utf8File = decode_utf8(file_contents)
|
129 |
+
if (utf8File is None) :
|
130 |
+
raise Exception("The file %s is not UTF-8" %fileName)
|
131 |
+
|
132 |
+
lines = utf8File.split( "\r\n" if CRLF else "\n" )
|
133 |
+
for line in lines:
|
134 |
+
line = line.replace("\r","").replace("\n","")
|
135 |
+
if(line != ""):
|
136 |
+
try:
|
137 |
+
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
138 |
+
except Exception as e:
|
139 |
+
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
|
144 |
+
"""
|
145 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
146 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
147 |
+
Posible values are:
|
148 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
149 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
150 |
+
"""
|
151 |
+
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
152 |
+
|
153 |
+
|
154 |
+
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
155 |
+
"""
|
156 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
157 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
158 |
+
Posible values are:
|
159 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
160 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
161 |
+
Returns values from a textline. Points , [Confidences], [Transcriptions]
|
162 |
+
"""
|
163 |
+
confidence = 0.0
|
164 |
+
transcription = "";
|
165 |
+
points = []
|
166 |
+
|
167 |
+
numPoints = 4;
|
168 |
+
|
169 |
+
if LTRB:
|
170 |
+
|
171 |
+
numPoints = 4;
|
172 |
+
|
173 |
+
if withTranscription and withConfidence:
|
174 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
175 |
+
if m == None :
|
176 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
177 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
|
178 |
+
elif withConfidence:
|
179 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
180 |
+
if m == None :
|
181 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
|
182 |
+
elif withTranscription:
|
183 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
|
184 |
+
if m == None :
|
185 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
|
186 |
+
else:
|
187 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
|
188 |
+
if m == None :
|
189 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
|
190 |
+
|
191 |
+
xmin = int(m.group(1))
|
192 |
+
ymin = int(m.group(2))
|
193 |
+
xmax = int(m.group(3))
|
194 |
+
ymax = int(m.group(4))
|
195 |
+
if(xmax<xmin):
|
196 |
+
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
|
197 |
+
if(ymax<ymin):
|
198 |
+
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
|
199 |
+
|
200 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
201 |
+
|
202 |
+
if (imWidth>0 and imHeight>0):
|
203 |
+
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
|
204 |
+
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
|
205 |
+
|
206 |
+
else:
|
207 |
+
|
208 |
+
numPoints = 8;
|
209 |
+
|
210 |
+
if withTranscription and withConfidence:
|
211 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
212 |
+
if m == None :
|
213 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
|
214 |
+
elif withConfidence:
|
215 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
216 |
+
if m == None :
|
217 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
|
218 |
+
elif withTranscription:
|
219 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
|
220 |
+
if m == None :
|
221 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
|
222 |
+
else:
|
223 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
|
224 |
+
if m == None :
|
225 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
|
226 |
+
|
227 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
228 |
+
|
229 |
+
points = order_points_clockwise(np.array(points).reshape(-1, 2)).reshape(-1)
|
230 |
+
validate_clockwise_points(points)
|
231 |
+
|
232 |
+
if (imWidth>0 and imHeight>0):
|
233 |
+
validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
|
234 |
+
validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
|
235 |
+
validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
|
236 |
+
validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
|
237 |
+
|
238 |
+
|
239 |
+
if withConfidence:
|
240 |
+
try:
|
241 |
+
confidence = float(m.group(numPoints+1))
|
242 |
+
except ValueError:
|
243 |
+
raise Exception("Confidence value must be a float")
|
244 |
+
|
245 |
+
if withTranscription:
|
246 |
+
posTranscription = numPoints + (2 if withConfidence else 1)
|
247 |
+
transcription = m.group(posTranscription)
|
248 |
+
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
|
249 |
+
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
|
250 |
+
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
|
251 |
+
|
252 |
+
return points,confidence,transcription
|
253 |
+
|
254 |
+
|
255 |
+
def validate_point_inside_bounds(x,y,imWidth,imHeight):
|
256 |
+
if(x<0 or x>imWidth):
|
257 |
+
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
|
258 |
+
if(y<0 or y>imHeight):
|
259 |
+
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
|
260 |
+
|
261 |
+
def validate_clockwise_points(points):
|
262 |
+
"""
|
263 |
+
Validates that the points that the 4 points that dlimite a polygon are in clockwise order.
|
264 |
+
"""
|
265 |
+
|
266 |
+
if len(points) != 8:
|
267 |
+
raise Exception("Points list not valid." + str(len(points)))
|
268 |
+
|
269 |
+
point = [
|
270 |
+
[int(points[0]) , int(points[1])],
|
271 |
+
[int(points[2]) , int(points[3])],
|
272 |
+
[int(points[4]) , int(points[5])],
|
273 |
+
[int(points[6]) , int(points[7])]
|
274 |
+
]
|
275 |
+
edge = [
|
276 |
+
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
|
277 |
+
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
|
278 |
+
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
|
279 |
+
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
|
280 |
+
]
|
281 |
+
|
282 |
+
summatory = edge[0] + edge[1] + edge[2] + edge[3];
|
283 |
+
if summatory>0:
|
284 |
+
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
|
285 |
+
|
286 |
+
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
|
287 |
+
"""
|
288 |
+
Returns all points, confindences and transcriptions of a file in lists. Valid line formats:
|
289 |
+
xmin,ymin,xmax,ymax,[confidence],[transcription]
|
290 |
+
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
|
291 |
+
"""
|
292 |
+
pointsList = []
|
293 |
+
transcriptionsList = []
|
294 |
+
confidencesList = []
|
295 |
+
|
296 |
+
lines = content.split( "\r\n" if CRLF else "\n" )
|
297 |
+
for line in lines:
|
298 |
+
line = line.replace("\r","").replace("\n","")
|
299 |
+
if(line != "") :
|
300 |
+
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
|
301 |
+
pointsList.append(points)
|
302 |
+
transcriptionsList.append(transcription)
|
303 |
+
confidencesList.append(confidence)
|
304 |
+
|
305 |
+
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
|
306 |
+
import numpy as np
|
307 |
+
sorted_ind = np.argsort(-np.array(confidencesList))
|
308 |
+
confidencesList = [confidencesList[i] for i in sorted_ind]
|
309 |
+
pointsList = [pointsList[i] for i in sorted_ind]
|
310 |
+
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
|
311 |
+
|
312 |
+
return pointsList,confidencesList,transcriptionsList
|
313 |
+
|
314 |
+
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
|
315 |
+
"""
|
316 |
+
This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample.
|
317 |
+
Params:
|
318 |
+
p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used.
|
319 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
320 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
321 |
+
evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
|
322 |
+
"""
|
323 |
+
evalParams = default_evaluation_params_fn()
|
324 |
+
if 'p' in p.keys():
|
325 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
326 |
+
|
327 |
+
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
|
328 |
+
try:
|
329 |
+
# validate_data_fn(p['g'], p['s'], evalParams)
|
330 |
+
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
|
331 |
+
resDict.update(evalData)
|
332 |
+
|
333 |
+
except Exception as e:
|
334 |
+
traceback.print_exc()
|
335 |
+
resDict['Message']= str(e)
|
336 |
+
resDict['calculated']=False
|
337 |
+
|
338 |
+
if 'o' in p:
|
339 |
+
if not os.path.exists(p['o']):
|
340 |
+
os.makedirs(p['o'])
|
341 |
+
|
342 |
+
resultsOutputname = p['o'] + '/results.zip'
|
343 |
+
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
|
344 |
+
|
345 |
+
del resDict['per_sample']
|
346 |
+
if 'output_items' in resDict.keys():
|
347 |
+
del resDict['output_items']
|
348 |
+
|
349 |
+
outZip.writestr('method.json',json.dumps(resDict))
|
350 |
+
|
351 |
+
if not resDict['calculated']:
|
352 |
+
if show_result:
|
353 |
+
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
|
354 |
+
if 'o' in p:
|
355 |
+
outZip.close()
|
356 |
+
return resDict
|
357 |
+
|
358 |
+
if 'o' in p:
|
359 |
+
if per_sample == True:
|
360 |
+
for k,v in evalData['per_sample'].iteritems():
|
361 |
+
outZip.writestr( k + '.json',json.dumps(v))
|
362 |
+
|
363 |
+
if 'output_items' in evalData.keys():
|
364 |
+
for k, v in evalData['output_items'].iteritems():
|
365 |
+
outZip.writestr( k,v)
|
366 |
+
|
367 |
+
outZip.close()
|
368 |
+
|
369 |
+
if show_result:
|
370 |
+
sys.stdout.write("Calculated!")
|
371 |
+
sys.stdout.write(json.dumps(resDict['method']))
|
372 |
+
|
373 |
+
return resDict
|
374 |
+
|
375 |
+
|
376 |
+
def main_validation(default_evaluation_params_fn,validate_data_fn):
|
377 |
+
"""
|
378 |
+
This process validates a method
|
379 |
+
Params:
|
380 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
381 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
382 |
+
"""
|
383 |
+
try:
|
384 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
385 |
+
evalParams = default_evaluation_params_fn()
|
386 |
+
if 'p' in p.keys():
|
387 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
388 |
+
|
389 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
390 |
+
print('SUCCESS')
|
391 |
+
sys.exit(0)
|
392 |
+
except Exception as e:
|
393 |
+
print(str(e))
|
394 |
+
sys.exit(101)
|
utils/cal_recall/script.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from collections import namedtuple
|
4 |
+
from . import rrc_evaluation_funcs
|
5 |
+
#import Polygon as plg
|
6 |
+
import shapely.geometry as plg
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def default_evaluation_params():
|
11 |
+
"""
|
12 |
+
default_evaluation_params: Default parameters to use for the validation and evaluation.
|
13 |
+
"""
|
14 |
+
return {
|
15 |
+
'IOU_CONSTRAINT': 0.5,
|
16 |
+
'AREA_PRECISION_CONSTRAINT': 0.5,
|
17 |
+
'GT_SAMPLE_NAME_2_ID': '.+_([0-9]+).txt',
|
18 |
+
'DET_SAMPLE_NAME_2_ID': 'res_.+_([0-9]+).txt',
|
19 |
+
'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
|
20 |
+
'CRLF': False, # Lines are delimited by Windows CRLF format
|
21 |
+
'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated
|
22 |
+
'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def validate_data(gtFilePath, submFilePath, evaluationParams):
|
27 |
+
"""
|
28 |
+
Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
|
29 |
+
Validates also that there are no missing files in the folder.
|
30 |
+
If some error detected, the method raises the error
|
31 |
+
"""
|
32 |
+
gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
33 |
+
|
34 |
+
subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
|
35 |
+
|
36 |
+
# Validate format of GroundTruth
|
37 |
+
for k in gt:
|
38 |
+
rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
|
39 |
+
|
40 |
+
# Validate format of results
|
41 |
+
for k in subm:
|
42 |
+
if (k in gt) == False:
|
43 |
+
raise Exception("The sample %s not present in GT" % k)
|
44 |
+
|
45 |
+
rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
|
46 |
+
False, evaluationParams['CONFIDENCES'])
|
47 |
+
|
48 |
+
|
49 |
+
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
|
50 |
+
"""
|
51 |
+
Method evaluate_method: evaluate method and returns the results
|
52 |
+
Results. Dictionary with the following values:
|
53 |
+
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
|
54 |
+
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
|
55 |
+
"""
|
56 |
+
|
57 |
+
def polygon_from_points(points):
|
58 |
+
"""
|
59 |
+
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
|
60 |
+
"""
|
61 |
+
resBoxes = np.empty([1, 8], dtype='int32')
|
62 |
+
resBoxes[0, 0] = int(points[0])
|
63 |
+
resBoxes[0, 4] = int(points[1])
|
64 |
+
resBoxes[0, 1] = int(points[2])
|
65 |
+
resBoxes[0, 5] = int(points[3])
|
66 |
+
resBoxes[0, 2] = int(points[4])
|
67 |
+
resBoxes[0, 6] = int(points[5])
|
68 |
+
resBoxes[0, 3] = int(points[6])
|
69 |
+
resBoxes[0, 7] = int(points[7])
|
70 |
+
pointMat = resBoxes[0].reshape([2, 4]).T
|
71 |
+
return plg.Polygon(pointMat)
|
72 |
+
|
73 |
+
def rectangle_to_polygon(rect):
|
74 |
+
resBoxes = np.empty([1, 8], dtype='int32')
|
75 |
+
resBoxes[0, 0] = int(rect.xmin)
|
76 |
+
resBoxes[0, 4] = int(rect.ymax)
|
77 |
+
resBoxes[0, 1] = int(rect.xmin)
|
78 |
+
resBoxes[0, 5] = int(rect.ymin)
|
79 |
+
resBoxes[0, 2] = int(rect.xmax)
|
80 |
+
resBoxes[0, 6] = int(rect.ymin)
|
81 |
+
resBoxes[0, 3] = int(rect.xmax)
|
82 |
+
resBoxes[0, 7] = int(rect.ymax)
|
83 |
+
|
84 |
+
pointMat = resBoxes[0].reshape([2, 4]).T
|
85 |
+
|
86 |
+
return plg.Polygon(pointMat)
|
87 |
+
|
88 |
+
def rectangle_to_points(rect):
|
89 |
+
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
|
90 |
+
int(rect.xmin), int(rect.ymin)]
|
91 |
+
return points
|
92 |
+
|
93 |
+
def get_union(pD, pG):
|
94 |
+
areaA = pD.area;
|
95 |
+
areaB = pG.area;
|
96 |
+
return areaA + areaB - get_intersection(pD, pG);
|
97 |
+
|
98 |
+
def get_intersection_over_union(pD, pG):
|
99 |
+
try:
|
100 |
+
return get_intersection(pD, pG) / get_union(pD, pG);
|
101 |
+
except:
|
102 |
+
return 0
|
103 |
+
|
104 |
+
def get_intersection(pD, pG):
|
105 |
+
pInt = pD & pG
|
106 |
+
if pInt.is_empty:
|
107 |
+
return 0
|
108 |
+
return pInt.area
|
109 |
+
|
110 |
+
def compute_ap(confList, matchList, numGtCare):
|
111 |
+
correct = 0
|
112 |
+
AP = 0
|
113 |
+
if len(confList) > 0:
|
114 |
+
confList = np.array(confList)
|
115 |
+
matchList = np.array(matchList)
|
116 |
+
sorted_ind = np.argsort(-confList)
|
117 |
+
confList = confList[sorted_ind]
|
118 |
+
matchList = matchList[sorted_ind]
|
119 |
+
for n in range(len(confList)):
|
120 |
+
match = matchList[n]
|
121 |
+
if match:
|
122 |
+
correct += 1
|
123 |
+
AP += float(correct) / (n + 1)
|
124 |
+
|
125 |
+
if numGtCare > 0:
|
126 |
+
AP /= numGtCare
|
127 |
+
|
128 |
+
return AP
|
129 |
+
|
130 |
+
perSampleMetrics = {}
|
131 |
+
|
132 |
+
matchedSum = 0
|
133 |
+
|
134 |
+
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
135 |
+
|
136 |
+
gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
137 |
+
subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
|
138 |
+
|
139 |
+
numGlobalCareGt = 0;
|
140 |
+
numGlobalCareDet = 0;
|
141 |
+
|
142 |
+
arrGlobalConfidences = [];
|
143 |
+
arrGlobalMatches = [];
|
144 |
+
|
145 |
+
for resFile in gt:
|
146 |
+
|
147 |
+
gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile])
|
148 |
+
recall = 0
|
149 |
+
precision = 0
|
150 |
+
hmean = 0
|
151 |
+
|
152 |
+
detMatched = 0
|
153 |
+
|
154 |
+
iouMat = np.empty([1, 1])
|
155 |
+
|
156 |
+
gtPols = []
|
157 |
+
detPols = []
|
158 |
+
|
159 |
+
gtPolPoints = []
|
160 |
+
detPolPoints = []
|
161 |
+
|
162 |
+
# Array of Ground Truth Polygons' keys marked as don't Care
|
163 |
+
gtDontCarePolsNum = []
|
164 |
+
# Array of Detected Polygons' matched with a don't Care GT
|
165 |
+
detDontCarePolsNum = []
|
166 |
+
|
167 |
+
pairs = []
|
168 |
+
detMatchedNums = []
|
169 |
+
|
170 |
+
arrSampleConfidences = [];
|
171 |
+
arrSampleMatch = [];
|
172 |
+
sampleAP = 0;
|
173 |
+
|
174 |
+
evaluationLog = ""
|
175 |
+
|
176 |
+
pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,
|
177 |
+
evaluationParams[
|
178 |
+
'CRLF'],
|
179 |
+
evaluationParams[
|
180 |
+
'LTRB'],
|
181 |
+
True, False)
|
182 |
+
for n in range(len(pointsList)):
|
183 |
+
points = pointsList[n]
|
184 |
+
transcription = transcriptionsList[n]
|
185 |
+
dontCare = transcription == "###"
|
186 |
+
if evaluationParams['LTRB']:
|
187 |
+
gtRect = Rectangle(*points)
|
188 |
+
gtPol = rectangle_to_polygon(gtRect)
|
189 |
+
else:
|
190 |
+
gtPol = polygon_from_points(points)
|
191 |
+
gtPols.append(gtPol)
|
192 |
+
gtPolPoints.append(points)
|
193 |
+
if dontCare:
|
194 |
+
gtDontCarePolsNum.append(len(gtPols) - 1)
|
195 |
+
|
196 |
+
evaluationLog += "GT polygons: " + str(len(gtPols)) + (
|
197 |
+
" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
|
198 |
+
|
199 |
+
if resFile in subm:
|
200 |
+
|
201 |
+
detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
|
202 |
+
|
203 |
+
pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,
|
204 |
+
evaluationParams[
|
205 |
+
'CRLF'],
|
206 |
+
evaluationParams[
|
207 |
+
'LTRB'],
|
208 |
+
False,
|
209 |
+
evaluationParams[
|
210 |
+
'CONFIDENCES'])
|
211 |
+
for n in range(len(pointsList)):
|
212 |
+
points = pointsList[n]
|
213 |
+
|
214 |
+
if evaluationParams['LTRB']:
|
215 |
+
detRect = Rectangle(*points)
|
216 |
+
detPol = rectangle_to_polygon(detRect)
|
217 |
+
else:
|
218 |
+
detPol = polygon_from_points(points)
|
219 |
+
detPols.append(detPol)
|
220 |
+
detPolPoints.append(points)
|
221 |
+
if len(gtDontCarePolsNum) > 0:
|
222 |
+
for dontCarePol in gtDontCarePolsNum:
|
223 |
+
dontCarePol = gtPols[dontCarePol]
|
224 |
+
intersected_area = get_intersection(dontCarePol, detPol)
|
225 |
+
pdDimensions = detPol.area
|
226 |
+
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
227 |
+
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
|
228 |
+
detDontCarePolsNum.append(len(detPols) - 1)
|
229 |
+
break
|
230 |
+
|
231 |
+
evaluationLog += "DET polygons: " + str(len(detPols)) + (
|
232 |
+
" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
|
233 |
+
|
234 |
+
if len(gtPols) > 0 and len(detPols) > 0:
|
235 |
+
# Calculate IoU and precision matrixs
|
236 |
+
outputShape = [len(gtPols), len(detPols)]
|
237 |
+
iouMat = np.empty(outputShape)
|
238 |
+
gtRectMat = np.zeros(len(gtPols), np.int8)
|
239 |
+
detRectMat = np.zeros(len(detPols), np.int8)
|
240 |
+
for gtNum in range(len(gtPols)):
|
241 |
+
for detNum in range(len(detPols)):
|
242 |
+
pG = gtPols[gtNum]
|
243 |
+
pD = detPols[detNum]
|
244 |
+
iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
|
245 |
+
|
246 |
+
for gtNum in range(len(gtPols)):
|
247 |
+
for detNum in range(len(detPols)):
|
248 |
+
if gtRectMat[gtNum] == 0 and detRectMat[
|
249 |
+
detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
|
250 |
+
if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
|
251 |
+
gtRectMat[gtNum] = 1
|
252 |
+
detRectMat[detNum] = 1
|
253 |
+
detMatched += 1
|
254 |
+
pairs.append({'gt': gtNum, 'det': detNum})
|
255 |
+
detMatchedNums.append(detNum)
|
256 |
+
evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
|
257 |
+
|
258 |
+
if evaluationParams['CONFIDENCES']:
|
259 |
+
for detNum in range(len(detPols)):
|
260 |
+
if detNum not in detDontCarePolsNum:
|
261 |
+
# we exclude the don't care detections
|
262 |
+
match = detNum in detMatchedNums
|
263 |
+
|
264 |
+
arrSampleConfidences.append(confidencesList[detNum])
|
265 |
+
arrSampleMatch.append(match)
|
266 |
+
|
267 |
+
arrGlobalConfidences.append(confidencesList[detNum]);
|
268 |
+
arrGlobalMatches.append(match);
|
269 |
+
|
270 |
+
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
271 |
+
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
272 |
+
if numGtCare == 0:
|
273 |
+
recall = float(1)
|
274 |
+
precision = float(0) if numDetCare > 0 else float(1)
|
275 |
+
sampleAP = precision
|
276 |
+
else:
|
277 |
+
recall = float(detMatched) / numGtCare
|
278 |
+
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
|
279 |
+
if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
|
280 |
+
sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
|
281 |
+
|
282 |
+
hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
|
283 |
+
|
284 |
+
matchedSum += detMatched
|
285 |
+
numGlobalCareGt += numGtCare
|
286 |
+
numGlobalCareDet += numDetCare
|
287 |
+
|
288 |
+
if evaluationParams['PER_SAMPLE_RESULTS']:
|
289 |
+
perSampleMetrics[resFile] = {
|
290 |
+
'precision': precision,
|
291 |
+
'recall': recall,
|
292 |
+
'hmean': hmean,
|
293 |
+
'pairs': pairs,
|
294 |
+
'AP': sampleAP,
|
295 |
+
'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
|
296 |
+
'gtPolPoints': gtPolPoints,
|
297 |
+
'detPolPoints': detPolPoints,
|
298 |
+
'gtDontCare': gtDontCarePolsNum,
|
299 |
+
'detDontCare': detDontCarePolsNum,
|
300 |
+
'evaluationParams': evaluationParams,
|
301 |
+
'evaluationLog': evaluationLog
|
302 |
+
}
|
303 |
+
|
304 |
+
# Compute MAP and MAR
|
305 |
+
AP = 0
|
306 |
+
if evaluationParams['CONFIDENCES']:
|
307 |
+
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
|
308 |
+
|
309 |
+
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
|
310 |
+
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
|
311 |
+
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
|
312 |
+
methodRecall + methodPrecision)
|
313 |
+
|
314 |
+
methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
|
315 |
+
|
316 |
+
resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
|
317 |
+
|
318 |
+
return resDict;
|
319 |
+
|
320 |
+
|
321 |
+
def cal_recall_precison_f1(gt_path, result_path, show_result=False):
|
322 |
+
p = {'g': gt_path, 's': result_path}
|
323 |
+
result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method,
|
324 |
+
show_result)
|
325 |
+
return result['method']
|
utils/make_trainfile.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import pathlib
|
6 |
+
|
7 |
+
data_path = r'E:\zj\dataset\icdar2015\test'
|
8 |
+
# data_path/img 存放图片
|
9 |
+
# data_path/gt 存放标签文件
|
10 |
+
|
11 |
+
f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
|
12 |
+
for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
|
13 |
+
d = pathlib.Path(img_path)
|
14 |
+
label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
|
15 |
+
if os.path.exists(img_path) and os.path.exists(label_path):
|
16 |
+
print(img_path, label_path)
|
17 |
+
else:
|
18 |
+
print('不存在', img_path, label_path)
|
19 |
+
f_w.write('{}\t{}\n'.format(img_path, label_path))
|
20 |
+
f_w.close()
|
utils/metrics.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from score written by wkentaro
|
2 |
+
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class runningScore(object):
|
8 |
+
|
9 |
+
def __init__(self, n_classes):
|
10 |
+
self.n_classes = n_classes
|
11 |
+
self.confusion_matrix = np.zeros((n_classes, n_classes))
|
12 |
+
|
13 |
+
def _fast_hist(self, label_true, label_pred, n_class):
|
14 |
+
mask = (label_true >= 0) & (label_true < n_class)
|
15 |
+
|
16 |
+
if np.sum((label_pred[mask] < 0)) > 0:
|
17 |
+
print(label_pred[label_pred < 0])
|
18 |
+
hist = np.bincount(n_class * label_true[mask].astype(int) +
|
19 |
+
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
|
20 |
+
return hist
|
21 |
+
|
22 |
+
def update(self, label_trues, label_preds):
|
23 |
+
# print label_trues.dtype, label_preds.dtype
|
24 |
+
for lt, lp in zip(label_trues, label_preds):
|
25 |
+
try:
|
26 |
+
self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
|
27 |
+
except:
|
28 |
+
pass
|
29 |
+
|
30 |
+
def get_scores(self):
|
31 |
+
"""Returns accuracy score evaluation result.
|
32 |
+
- overall accuracy
|
33 |
+
- mean accuracy
|
34 |
+
- mean IU
|
35 |
+
- fwavacc
|
36 |
+
"""
|
37 |
+
hist = self.confusion_matrix
|
38 |
+
acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
|
39 |
+
acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
|
40 |
+
acc_cls = np.nanmean(acc_cls)
|
41 |
+
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
|
42 |
+
mean_iu = np.nanmean(iu)
|
43 |
+
freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
|
44 |
+
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
|
45 |
+
cls_iu = dict(zip(range(self.n_classes), iu))
|
46 |
+
|
47 |
+
return {'Overall Acc': acc,
|
48 |
+
'Mean Acc': acc_cls,
|
49 |
+
'FreqW Acc': fwavacc,
|
50 |
+
'Mean IoU': mean_iu, }, cls_iu
|
51 |
+
|
52 |
+
def reset(self):
|
53 |
+
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
|
utils/schedulers.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
2 |
+
|
3 |
+
|
4 |
+
class ConstantLR(_LRScheduler):
|
5 |
+
def __init__(self, optimizer, last_epoch=-1):
|
6 |
+
super(ConstantLR, self).__init__(optimizer, last_epoch)
|
7 |
+
|
8 |
+
def get_lr(self):
|
9 |
+
return [base_lr for base_lr in self.base_lrs]
|
10 |
+
|
11 |
+
|
12 |
+
class PolynomialLR(_LRScheduler):
|
13 |
+
def __init__(self, optimizer, max_iter, power=0.9, last_epoch=-1):
|
14 |
+
self.max_iter = max_iter
|
15 |
+
self.power = power
|
16 |
+
super(PolynomialLR, self).__init__(optimizer, last_epoch)
|
17 |
+
|
18 |
+
def get_lr(self):
|
19 |
+
factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power
|
20 |
+
return [base_lr * factor for base_lr in self.base_lrs]
|
21 |
+
|
22 |
+
|
23 |
+
class WarmUpLR(_LRScheduler):
|
24 |
+
def __init__(
|
25 |
+
self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
|
26 |
+
):
|
27 |
+
self.mode = mode
|
28 |
+
self.scheduler = scheduler
|
29 |
+
self.warmup_iters = warmup_iters
|
30 |
+
self.gamma = gamma
|
31 |
+
super(WarmUpLR, self).__init__(optimizer, last_epoch)
|
32 |
+
|
33 |
+
def get_lr(self):
|
34 |
+
cold_lrs = self.scheduler.get_lr()
|
35 |
+
|
36 |
+
if self.last_epoch < self.warmup_iters:
|
37 |
+
if self.mode == "linear":
|
38 |
+
alpha = self.last_epoch / float(self.warmup_iters)
|
39 |
+
factor = self.gamma * (1 - alpha) + alpha
|
40 |
+
|
41 |
+
elif self.mode == "constant":
|
42 |
+
factor = self.gamma
|
43 |
+
else:
|
44 |
+
raise KeyError("WarmUp type {} not implemented".format(self.mode))
|
45 |
+
|
46 |
+
return [factor * base_lr for base_lr in cold_lrs]
|
47 |
+
|
48 |
+
return cold_lrs
|
49 |
+
|
50 |
+
if __name__ == '__main__':
|
51 |
+
import torch
|
52 |
+
from torchvision.models import resnet18
|
53 |
+
max_iter = 600 * 125
|
54 |
+
model = resnet18()
|
55 |
+
op = torch.optim.SGD(model.parameters(),0.001)
|
56 |
+
sc = PolynomialLR(op,max_iter)
|
57 |
+
lr = []
|
58 |
+
for i in range(max_iter):
|
59 |
+
sc.step()
|
60 |
+
print(i,sc.last_epoch,sc.get_lr()[0])
|
61 |
+
lr.append(sc.get_lr()[0])
|
62 |
+
from matplotlib import pyplot as plt
|
63 |
+
plt.plot(list(range(max_iter)),lr)
|
64 |
+
plt.show()
|
utils/util.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import time
|
4 |
+
import json
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
|
11 |
+
def setup_logger(log_file_path: str = None):
|
12 |
+
import logging
|
13 |
+
from colorlog import ColoredFormatter
|
14 |
+
logging.basicConfig(filename=log_file_path,
|
15 |
+
format='%(asctime)s %(levelname)-8s %(filename)s[line:%(lineno)d]: %(message)s',
|
16 |
+
# 定义输出log的格式
|
17 |
+
datefmt='%Y-%m-%d %H:%M:%S', )
|
18 |
+
"""Return a logger with a default ColoredFormatter."""
|
19 |
+
formatter = ColoredFormatter(
|
20 |
+
"%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s[line:%(lineno)d]: %(message)s",
|
21 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
22 |
+
reset=True,
|
23 |
+
log_colors={
|
24 |
+
'DEBUG': 'blue',
|
25 |
+
'INFO': 'green',
|
26 |
+
'WARNING': 'yellow',
|
27 |
+
'ERROR': 'red',
|
28 |
+
'CRITICAL': 'red',
|
29 |
+
})
|
30 |
+
|
31 |
+
logger = logging.getLogger('PAN')
|
32 |
+
handler = logging.StreamHandler()
|
33 |
+
handler.setFormatter(formatter)
|
34 |
+
logger.addHandler(handler)
|
35 |
+
logger.setLevel(logging.DEBUG)
|
36 |
+
logger.info('logger init finished')
|
37 |
+
return logger
|
38 |
+
|
39 |
+
|
40 |
+
# --exeTime
|
41 |
+
def exe_time(func):
|
42 |
+
def newFunc(*args, **args2):
|
43 |
+
t0 = time.time()
|
44 |
+
back = func(*args, **args2)
|
45 |
+
print("{} cost {:.3f}s".format(func.__name__, time.time() - t0))
|
46 |
+
return back
|
47 |
+
|
48 |
+
return newFunc
|
49 |
+
|
50 |
+
|
51 |
+
def save_json(data, json_path):
|
52 |
+
with open(json_path, mode='w', encoding='utf8') as f:
|
53 |
+
json.dump(data, f, indent=4)
|
54 |
+
|
55 |
+
|
56 |
+
def load_json(json_path):
|
57 |
+
with open(json_path, mode='r', encoding='utf8') as f:
|
58 |
+
data = json.load(f)
|
59 |
+
return data
|
60 |
+
|
61 |
+
|
62 |
+
def show_img(imgs: np.ndarray, color=False):
|
63 |
+
if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color):
|
64 |
+
imgs = np.expand_dims(imgs, axis=0)
|
65 |
+
for img in imgs:
|
66 |
+
plt.figure()
|
67 |
+
plt.imshow(img, cmap=None if color else 'gray')
|
68 |
+
|
69 |
+
|
70 |
+
def draw_bbox(img_path, result, color=(128, 240, 128), thickness=3):
|
71 |
+
if isinstance(img_path, str):
|
72 |
+
img_path = cv2.imread(img_path)
|
73 |
+
# img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
|
74 |
+
img_path = img_path.copy()
|
75 |
+
for point in result:
|
76 |
+
point = point.astype(int)
|
77 |
+
cv2.line(img_path, tuple(point[0]), tuple(point[1]), color, thickness)
|
78 |
+
cv2.line(img_path, tuple(point[1]), tuple(point[2]), color, thickness)
|
79 |
+
cv2.line(img_path, tuple(point[2]), tuple(point[3]), color, thickness)
|
80 |
+
cv2.line(img_path, tuple(point[3]), tuple(point[0]), color, thickness)
|
81 |
+
return img_path
|
82 |
+
|
83 |
+
|
84 |
+
def cal_text_score(texts, gt_texts, training_masks, running_metric_text):
|
85 |
+
training_masks = training_masks.data.cpu().numpy()
|
86 |
+
pred_text = torch.sigmoid(texts).data.cpu().numpy() * training_masks
|
87 |
+
pred_text[pred_text <= 0.5] = 0
|
88 |
+
pred_text[pred_text > 0.5] = 1
|
89 |
+
pred_text = pred_text.astype(np.int32)
|
90 |
+
gt_text = gt_texts.data.cpu().numpy() * training_masks
|
91 |
+
gt_text = gt_text.astype(np.int32)
|
92 |
+
running_metric_text.update(gt_text, pred_text)
|
93 |
+
score_text, _ = running_metric_text.get_scores()
|
94 |
+
return score_text
|
95 |
+
|
96 |
+
|
97 |
+
def cal_kernel_score(kernel, gt_kernel, gt_texts, training_masks, running_metric_kernel):
|
98 |
+
mask = (gt_texts * training_masks.float()).data.cpu().numpy()
|
99 |
+
pred_kernel = torch.sigmoid(kernel).data.cpu().numpy()
|
100 |
+
pred_kernel[pred_kernel <= 0.5] = 0
|
101 |
+
pred_kernel[pred_kernel > 0.5] = 1
|
102 |
+
pred_kernel = (pred_kernel * mask).astype(np.int32)
|
103 |
+
gt_kernel = gt_kernel.data.cpu().numpy()
|
104 |
+
gt_kernel = (gt_kernel * mask).astype(np.int32)
|
105 |
+
running_metric_kernel.update(gt_kernel, pred_kernel)
|
106 |
+
score_kernel, _ = running_metric_kernel.get_scores()
|
107 |
+
return score_kernel
|
108 |
+
|
109 |
+
|
110 |
+
def order_points_clockwise(pts):
|
111 |
+
rect = np.zeros((4, 2), dtype="float32")
|
112 |
+
s = pts.sum(axis=1)
|
113 |
+
rect[0] = pts[np.argmin(s)]
|
114 |
+
rect[2] = pts[np.argmax(s)]
|
115 |
+
diff = np.diff(pts, axis=1)
|
116 |
+
rect[1] = pts[np.argmin(diff)]
|
117 |
+
rect[3] = pts[np.argmax(diff)]
|
118 |
+
return rect
|
119 |
+
|
120 |
+
|
121 |
+
def order_points_clockwise_list(pts):
|
122 |
+
pts = pts.tolist()
|
123 |
+
pts.sort(key=lambda x: (x[1], x[0]))
|
124 |
+
pts[:2] = sorted(pts[:2], key=lambda x: x[0])
|
125 |
+
pts[2:] = sorted(pts[2:], key=lambda x: -x[0])
|
126 |
+
pts = np.array(pts)
|
127 |
+
return pts
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
#box = np.array([382, 1080, 443, 999, 423, 1014, 362, 1095]).reshape(-1, 2)
|
132 |
+
box = np.array([0, 4, 2, 2, 0, 8, 4, 4]).reshape(-1, 2)
|
133 |
+
# box = np.array([0, 0, 2, 2, 0, 4, 4, 4]).reshape(-1, 2)
|
134 |
+
from scipy.spatial import ConvexHull
|
135 |
+
print(box)
|
136 |
+
print(order_points_clockwise(box))
|
137 |
+
print(order_points_clockwise_list(box))
|