# -*- coding: utf-8 -*-

import torch
from torchvision import transforms
import os
import cv2
import time
import numpy as np

import pyclipper
import gradio as gr

from models import get_model
from utils.util import show_img, draw_bbox

class Pytorch_model:
    def __init__(self, model_path, gpu_id=None):
        '''
        Initialize the PyTorch model.
        :param model_path: path to the model file (either the parameters alone, or the parameters saved together with the computation graph)
        :param gpu_id: id of the GPU to run on (None to run on the CPU)
        '''
        self.gpu_id = gpu_id
        
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
            checkpoint = torch.load(model_path, map_location=self.device)
        else:
            self.device = torch.device("cpu")
            checkpoint = torch.load(model_path, map_location='cpu')
        print('device:', self.device)

        config = checkpoint['config']
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)

        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])  # load weights
        self.net.to(self.device)
        self.net.eval()

    def predict(self, img, short_size: int = 736, min_area: int = 100):
        '''
        Run detection on one image. Accepts either an image path (read with OpenCV,
        slightly slower) or an image already loaded as a numpy array (e.g. as passed in by Gradio).
        :param img: image path or numpy array (assumed to be RGB if already loaded)
        :param short_size: the shorter side of the image is resized to this length
        :param min_area: bounding boxes with an area smaller than this are ignored
        :return: dilated polygons, bounding boxes, elapsed time in seconds
        '''
        if isinstance(img, str):
            img = cv2.imread(img)  # OpenCV reads images in BGR order
            if self.img_channel == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        scale = short_size / min(h, w)
        img = cv2.resize(img, None, fx=scale, fy=scale)

        tensor = transforms.ToTensor()(img)
        tensor = tensor.unsqueeze_(0)

        tensor = tensor.to(self.device)
        with torch.no_grad():
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            start = time.time()
            preds = self.net(tensor)[0]
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            scale = (preds.shape[2] / w, preds.shape[1] / h)
            t = time.time() - start
            
        # post-processing
        start = time.time()
        prob_map, thres_map = preds[0], preds[1]

        # Step 1: binarize the probability map with a fixed threshold
        thr = 0.2
        out = (prob_map > thr).float() * 255
        out = out.data.cpu().numpy().astype(np.uint8)
        # cv2.imwrite('c_bin_map.png', out)

        # Step 2: find the contours of the connected components
        contours, hierarchy = cv2.findContours(out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # map contours back to the original image scale; keep polygons with at least 4 points
        contours = [(i / scale).astype(int) for i in contours if len(i) >= 4]
                
        # Step 3: dilate (unclip) the shrunk regions predicted by DB (optional)
        ratio_prime = 1.5
        dilated_polys = []
        for poly in contours:
            poly = poly[:, 0, :]
            # offset distance D' = A * r' / L, formula (10) in the paper
            D_prime = cv2.contourArea(poly) * ratio_prime / cv2.arcLength(poly, True)
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            dilated_poly = np.array(pco.Execute(D_prime))
            # keep only clean results: exactly one expanded polygon with integer coordinates
            if dilated_poly.size == 0 or dilated_poly.dtype != int or len(dilated_poly) != 1:
                continue
            dilated_polys.append(dilated_poly)
            
        boxes_list = []
        for cnt in dilated_polys:
            if cv2.contourArea(cnt) < min_area:  # drop regions that are too small
                continue
            rect = cv2.minAreaRect(cnt)
            box = (cv2.boxPoints(rect)).astype(int)
            boxes_list.append(box)
        
        t = time.time() - start + t
            
        boxes_list = np.array(boxes_list)
        return dilated_polys, boxes_list, t
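
# Standalone usage sketch (outside the Gradio demo). The checkpoint path below is the
# one this script already uses; 'test.jpg' is only a placeholder example image.
#
#   model = Pytorch_model('pre_trained/MobileNetv2_best_loss.pth', gpu_id=None)
#   polys, boxes, elapsed = model.predict('test.jpg', short_size=736, min_area=100)
#   print('found %d text regions in %.3f s' % (len(boxes), elapsed))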


def run(img):
    # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    model_path = 'pre_trained/MobileNetv2_best_loss.pth'

    # init model (note: it is re-initialised on every request; hoist it out of
    # run() if you want to load the weights only once)
    model = Pytorch_model(model_path, gpu_id=None)  # set a GPU id, or None to run on the CPU
    contours, boxes_list, t = model.predict(img)
    print('Time: %.4f' % t)

    # Gradio passes and expects RGB images, so draw directly on an RGB copy
    imgc = img.copy()
    cv2.drawContours(imgc, contours, -1, (22, 222, 22), 2, cv2.LINE_AA)
    return imgc
    # cv2.imwrite('contour.png', imgc)
    # img = draw_bbox(img, boxes_list)
    # cv2.imwrite('predict.jpg', img)

if __name__ == '__main__':
    iface = gr.Interface(fn=run,
        title="CJK Font Detection Using DBNet",
        description="Gives an image containing CJK fonts, we will mark the texts out!",
        inputs="image", 
        outputs="image")
    iface.launch()
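    # launch() serves the demo locally (Gradio defaults to http://127.0.0.1:7860);
    # passing share=True to iface.launch() should give a temporary public link if needed.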