Spaces:
Sleeping
Sleeping
File size: 4,853 Bytes
0742dfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import os
import cv2
import time
import numpy as np
import pyclipper
import gradio as gr
from models import get_model
from utils.util import show_img, draw_bbox
class Pytorch_model:
    """Inference wrapper around a text-detection network restored from a checkpoint.

    The checkpoint must be a dict holding a ``'config'`` entry (used to rebuild
    the network through ``get_model``) and a ``'state_dict'`` entry (the weights).
    """

    def __init__(self, model_path, gpu_id=None):
        """Load the checkpoint, rebuild the network and switch it to eval mode.

        :param model_path: path of the checkpoint file
        :param gpu_id: index of the GPU to run on; anything that is not an int,
            or a machine without CUDA, selects the CPU
        """
        self.gpu_id = gpu_id
        if isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        # Always remap to the target device so a checkpoint saved on another GPU
        # (or on a GPU that does not exist here) still loads.
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=self.device)
        print('device:', self.device)

        config = checkpoint['config']
        # The weights come from the checkpoint, so skip any pretrained download.
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)
        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])
        self.net.to(self.device)
        self.net.eval()

    @staticmethod
    def _synchronize(device):
        """Wait for pending CUDA work on ``device``; do nothing for non-CUDA devices."""
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    def _load_image(self, img):
        """Return ``img`` as an array, reading it with OpenCV when a path is given."""
        if isinstance(img, (str, os.PathLike)):
            path = os.fspath(img)
            img = cv2.imread(path)
            if img is None:
                raise ValueError('could not read image: %s' % path)
            if self.img_channel == 3:
                # OpenCV decodes to BGR; the network input is converted to RGB.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img
        if img is None:
            raise ValueError('img must be an image path or a numpy array, got None')
        # Arrays are used as given (no channel reordering).
        # NOTE(review): assumes array input is already RGB, as Gradio supplies - confirm.
        return np.asarray(img)

    def predict(self, img, short_size: int = 736, min_area: int = 100):
        """Detect text regions in an image.

        :param img: image path (read with OpenCV) or the image as a numpy array
        :param short_size: length the shorter image side is resized to before inference
        :param min_area: dilated polygons with a smaller contour area get no box
        :return: ``(dilated_polys, boxes_list, t)`` - the list of dilated polygons
            in original-image coordinates, an array with one min-area rectangle
            (4 corner points) per polygon that passed ``min_area``, and the
            elapsed time in seconds (network forward pass + post-processing)
        """
        img = self._load_image(img)
        h, w = img.shape[:2]
        scale = short_size / min(h, w)
        img = cv2.resize(img, None, fx=scale, fy=scale)
        tensor = transforms.ToTensor()(img).unsqueeze_(0).to(self.device)

        with torch.no_grad():
            # Synchronise around the forward pass so the timing covers the GPU work;
            # skipped on CPU, where torch.cuda.synchronize would raise.
            self._synchronize(self.device)
            start = time.time()
            preds = self.net(tensor)[0]
            self._synchronize(self.device)
            t = time.time() - start

        # Ratio between the prediction map and the ORIGINAL image, as (x, y).
        scale = (preds.shape[2] / w, preds.shape[1] / h)

        start = time.time()
        prob_map = preds[0]

        # Step 1: threshold the probability map into a binary map.
        thr = 0.2
        out = (prob_map > thr).float() * 255
        out = out.detach().cpu().numpy().astype(np.uint8)

        # Step 2: connected components via findContours.
        # OpenCV 3.x returns (image, contours, hierarchy), 4.x (contours, hierarchy);
        # the contours are the second-to-last item in both.
        contours = cv2.findContours(out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        # Map back to original-image coordinates; drop contours with fewer than 4 points.
        contours = [(c / scale).astype(int) for c in contours if len(c) >= 4]

        # Step 3: dilate the shrunk regions.
        ratio_prime = 1.5
        dilated_polys = []
        for poly in contours:
            poly = poly[:, 0, :]
            perimeter = cv2.arcLength(poly, True)
            if perimeter == 0:
                # All points collapsed onto one pixel after rescaling: nothing to offset.
                continue
            d_prime = cv2.contourArea(poly) * ratio_prime / perimeter  # formula(10) in the thesis
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            offset_paths = pco.Execute(d_prime)
            # Keep only results made of exactly one non-empty polygon. Checking the
            # list before building the array avoids ragged-array errors in NumPy.
            if len(offset_paths) != 1 or len(offset_paths[0]) == 0:
                continue
            dilated_polys.append(np.array(offset_paths))

        boxes_list = []
        for cnt in dilated_polys:
            if cv2.contourArea(cnt) < min_area:
                continue
            rect = cv2.minAreaRect(cnt)
            boxes_list.append(cv2.boxPoints(rect).astype(int))

        t = time.time() - start + t
        boxes_list = np.array(boxes_list)
        return dilated_polys, boxes_list, t
def run(img):
    """Gradio callback: outline the detected text regions on a copy of ``img``.

    :param img: image array handed over by the Gradio ``"image"`` input
        (presumably RGB, height x width x 3 - confirm for the Gradio version in use)
    :return: a copy of ``img`` with the detected contours drawn in green
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    model_path = 'pre_trained/MobileNetv2_best_loss.pth'

    # Build the model on the first request only and keep it on the function object;
    # reloading the checkpoint for every image is needlessly slow.
    model = getattr(run, '_model', None)
    if model is None:
        model = Pytorch_model(model_path, gpu_id=None)  # set a GPU id, or None for CPU only
        run._model = model

    contours, _boxes, elapsed = model.predict(img)
    print('Time: %.4f' % elapsed)

    # Draw on a copy so the caller's array is left untouched. The channel order is
    # kept as received: the result goes straight back to Gradio, so reversing the
    # channels here would swap red and blue in the displayed image.
    imgc = img.copy()
    cv2.drawContours(imgc, contours, -1, (22, 222, 22), 2, cv2.LINE_AA)
    return imgc
if __name__ == '__main__':
    # Expose `run` as a one-image-in / one-image-out web demo and start serving it.
    demo = gr.Interface(
        fn=run,
        inputs="image",
        outputs="image",
        title="CJK Font Detection Using DBNet",
        description="Gives an image containing CJK fonts, we will mark the texts out!",
    )
    demo.launch()
|