# Spaces:
# Sleeping
# Sleeping
# -*- coding: utf-8 -*- | |
import functools
import os
import time

import cv2
import gradio as gr
import numpy as np
import pyclipper
import torch
from torchvision import transforms

from models import get_model
from utils.util import show_img, draw_bbox
class Pytorch_model:
    '''
    Inference wrapper around a DBNet text detector: loads a checkpoint, runs the
    network on an image and post-processes the probability map into text polygons.
    '''

    def __init__(self, model_path, gpu_id=None):
        '''
        Initialize the PyTorch model.
        :param model_path: path to a checkpoint loadable with torch.load; it must be a
            dict holding the training 'config' and the network 'state_dict'
        :param gpu_id: index of the GPU to run on; None (or no available CUDA) selects
            the CPU
        '''
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        # Map tensors straight onto the chosen device, whatever device they were saved from.
        checkpoint = torch.load(model_path, map_location=self.device)
        print('device:', self.device)
        config = checkpoint['config']
        # Weights come from the checkpoint itself; presumably this flag stops get_model
        # from initializing the backbone with pretrained weights (TODO confirm).
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)
        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])  # load weights
        self.net.to(self.device)
        self.net.eval()

    def _synchronize(self):
        '''Wait for queued CUDA work so timings are accurate; no-op on the CPU.'''
        # torch.cuda.synchronize raises for a CPU device (or a CPU-only torch build).
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    def _load_image(self, img):
        '''
        Return the image to feed the network as an array.

        A str / path-like value is read from disk with OpenCV (which yields BGR) and
        converted to RGB when the model takes 3 channels. Any other value is assumed
        to already be an image array in the model's channel order (e.g. the RGB array
        Gradio supplies) and is returned through np.asarray.
        :raises OSError: if the file cannot be read or decoded
        '''
        if isinstance(img, (str, os.PathLike)):
            path = os.fspath(img)
            loaded = cv2.imread(path)
            if loaded is None:
                raise OSError('cannot read image: %s' % path)
            if self.img_channel == 3:
                loaded = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
            return loaded
        return np.asarray(img)

    @staticmethod
    def _dilate_contours(contours, ratio_prime):
        '''
        Expand each shrunk text region by D' = area * ratio_prime / perimeter
        (formula (10) of the DBNet paper) using pyclipper polygon offsetting.

        Contours with zero perimeter, and offsets that do not produce exactly one
        non-empty integer polygon, are dropped.
        :param contours: int point arrays in OpenCV (N, 1, 2) layout
        :return: list of int arrays of shape (1, K, 2)
        '''
        dilated_polys = []
        for contour in contours:
            poly = contour[:, 0, :]  # (N, 1, 2) -> (N, 2) points
            perimeter = cv2.arcLength(poly, True)
            if perimeter == 0:
                # All points coincide after rescaling; the offset distance is undefined.
                continue
            d_prime = cv2.contourArea(poly) * ratio_prime / perimeter
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            solution = pco.Execute(d_prime)
            # Check the path count before np.array(): several paths of different lengths
            # would form a ragged array, which NumPy >= 1.24 rejects with ValueError.
            if len(solution) != 1:
                continue
            dilated_poly = np.array(solution)
            if dilated_poly.size == 0 or dilated_poly.dtype != int:
                continue
            dilated_polys.append(dilated_poly)
        return dilated_polys

    def predict(self, img, short_size: int = 736, min_area: int = 100,
                thr: float = 0.2, ratio_prime: float = 1.5):
        '''
        Detect text regions in an image.
        :param img: path to an image file (read with OpenCV), or an H x W x C image
            array already in the model's channel order (e.g. the RGB array from Gradio)
        :param short_size: the image is resized so that its shorter side has this length
        :param min_area: rotated boxes whose polygon area is below this are dropped
            (the returned polygons themselves are not filtered)
        :param thr: probability threshold used to binarize the predicted map
        :param ratio_prime: expansion ratio r' used to dilate the shrunk text regions
        :return: (polygons, boxes, seconds): dilated polygons in original-image
            coordinates, an array of 4-point rotated boxes, and the time spent on
            network inference plus post-processing
        '''
        img = self._load_image(img)
        h, w = img.shape[:2]
        resize_ratio = short_size / min(h, w)
        img = cv2.resize(img, None, fx=resize_ratio, fy=resize_ratio)
        tensor = transforms.ToTensor()(img).unsqueeze_(0).to(self.device)

        with torch.no_grad():
            self._synchronize()
            start = time.time()
            preds = self.net(tensor)[0]
            self._synchronize()
        # (x, y) factors from original-image coordinates to prediction-map coordinates.
        scale = (preds.shape[2] / w, preds.shape[1] / h)
        t = time.time() - start

        start = time.time()
        # Step 1: binarize the probability map (channel 0; channel 1, the threshold
        # map, is not used here).
        prob_map = preds[0]
        binary = ((prob_map > thr).float() * 255).cpu().numpy().astype(np.uint8)
        # Step 2: outer contours of the connected regions. Taking [-2] works with both
        # OpenCV 3 (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        contours = [(c / scale).astype(int) for c in contours if len(c) >= 4]
        # Step 3: dilate the shrunk regions back towards full text size.
        dilated_polys = self._dilate_contours(contours, ratio_prime)
        # Step 4: minimum-area rotated rectangle around each large-enough polygon.
        boxes_list = []
        for cnt in dilated_polys:
            if cv2.contourArea(cnt) < min_area:
                continue
            rect = cv2.minAreaRect(cnt)
            boxes_list.append(cv2.boxPoints(rect).astype(int))
        t += time.time() - start
        return dilated_polys, np.array(boxes_list), t
# Checkpoint loaded by the demo.
MODEL_PATH = 'pre_trained/MobileNetv2_best_loss.pth'


@functools.lru_cache(maxsize=None)
def _get_model(model_path=MODEL_PATH, gpu_id=None):
    '''
    Build the detector once per (model_path, gpu_id) and reuse it on later calls,
    instead of reloading the checkpoint for every request.
    '''
    return Pytorch_model(model_path, gpu_id=gpu_id)


def run(img):
    '''
    Gradio callback: detect text in an image and draw the detected polygons on a copy.
    :param img: H x W x 3 image array from the Gradio "image" input (RGB), or None
        when no image was submitted
    :return: annotated copy of img in the same (RGB) channel order, or None
    '''
    if img is None:
        return None
    # Set gpu_id to a GPU index to run on CUDA; None uses the CPU.
    model = _get_model(MODEL_PATH, gpu_id=None)
    contours, _boxes, t = model.predict(img)
    print('Time: %.4f' % t)
    # Gradio both supplies and displays RGB arrays, so draw on an RGB copy; reversing
    # the channels (as cv2.imwrite would need) swaps red and blue in the UI.
    canvas = np.array(img)  # contiguous, writable copy for cv2.drawContours
    cv2.drawContours(canvas, contours, -1, (22, 222, 22), 2, cv2.LINE_AA)
    return canvas
if __name__ == '__main__':
    # Single image in, annotated image out; `run` does detection and drawing.
    demo = gr.Interface(
        fn=run,
        inputs="image",
        outputs="image",
        title="CJK Font Detection Using DBNet",
        description="Gives an image containing CJK fonts, we will mark the texts out!",
    )
    demo.launch()