import os
import shutil
import numpy as np
import einops
from typing import Union, Tuple
import cv2
import torch
from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
from .ctd_utils.utils.yolov5_utils import non_max_suppression
from .ctd_utils.utils.db_utils import SegDetectorRepresenter
from .ctd_utils.utils.imgproc_utils import letterbox
from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
from .common import OfflineDetector
from ..utils import Quadrilateral, det_rearrange_forward
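
# Text detector built on comic-text-detector. The model has three output heads:
# YOLO-style text block detections, a pixel-level text mask and a DBNet-style
# text lines map. Only the mask and the lines map are consumed here; the block
# detections are left disabled (see the comment near the end of _infer).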
def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
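    """Letterbox `img` to `input_size` and, if `to_tensor` is set, convert it to
    a normalized NCHW float tensor on `device`. Returns the prepared image, the
    letterbox resize ratio and the integer padding (dw, dh)."""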
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
    if to_tensor:
        # HWC to CHW. Note that [::-1] reverses the channel axis again, undoing
        # the BGR-to-RGB conversion above, so the tensor ends up in BGR order.
        img_in = img_in.transpose((2, 0, 1))[::-1]
        img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
        img_in = torch.from_numpy(img_in).to(device)
        if half:
            img_in = img_in.half()
    return img_in, ratio, int(dw), int(dh)

def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
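    """Convert a raw mask (torch tensor or numpy array) into a uint8 image in
    [0, 255], optionally binarizing it at `thresh` first."""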
    if isinstance(img, torch.Tensor):
        img = img.squeeze_()
        if img.device.type != 'cpu':
            img = img.detach().cpu()
        img = img.numpy()
    else:
        img = img.squeeze()
    if thresh is not None:
        img = img > thresh
    img = img * 255
    return img.astype(np.uint8)

def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
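    """Apply NMS to raw YOLO detections and rescale the boxes to the original
    image size. Returns (bboxes, classes, confidences). Within this module it is
    only referenced from the commented-out block-detection code in _infer."""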
    det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
    if isinstance(det, torch.Tensor):
        det = det.detach().cpu().numpy()
    det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
    det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
    if sort_func is not None:
        det = sort_func(det)
    blines = det[..., 0:4].astype(np.int32)
    confs = np.round(det[..., 4], 3)
    cls = det[..., 5].astype(np.int32)
    return blines, cls, confs

class ComicTextDetector(OfflineDetector):
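    # The same detector in two formats: a torch checkpoint served to the
    # 'cuda'/'mps' backend and an ONNX export run through OpenCV's DNN module
    # on CPU (see _load).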
_MODEL_MAPPING = {
'model-cuda': {
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
'file': '.',
},
'model-cpu': {
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
'file': '.',
},
}
def __init__(self, *args, **kwargs):
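        # Migrate model files that older versions downloaded into the current
        # working directory over to the managed model directory.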
os.makedirs(self.model_dir, exist_ok=True)
if os.path.exists('comictextdetector.pt'):
shutil.move('comictextdetector.pt', self._get_file_path('comictextdetector.pt'))
if os.path.exists('comictextdetector.pt.onnx'):
shutil.move('comictextdetector.pt.onnx', self._get_file_path('comictextdetector.pt.onnx'))
super().__init__(*args, **kwargs)
async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
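        """Load the torch checkpoint on 'cuda'/'mps'; on any other device fall
        back to the ONNX export executed through OpenCV's DNN module."""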
self.device = device
if self.device == 'cuda' or self.device == 'mps':
self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
self.model.to(self.device)
self.backend = 'torch'
else:
            model_path = self._get_file_path('comictextdetector.pt.onnx')
            self.model = TextDetBaseDNN(input_size, model_path)
self.backend = 'opencv'
if isinstance(input_size, int):
input_size = (input_size, input_size)
self.input_size = input_size
self.half = half
self.conf_thresh = conf_thresh
self.nms_thresh = nms_thresh
self.seg_rep = SegDetectorRepresenter(thresh=0.3)
async def _unload(self):
del self.model
def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
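        """Run the detector over a batch of NHWC uint8 patches and return the
        (lines, mask) maps as numpy arrays. Used as the batched forward callback
        for det_rearrange_forward in _infer."""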
if isinstance(self.model, TextDetBase):
batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
batch = torch.from_numpy(batch).to(device)
_, mask, lines = self.model(batch)
mask = mask.detach().cpu().numpy()
lines = lines.detach().cpu().numpy()
elif isinstance(self.model, TextDetBaseDNN):
mask_lst, line_lst = [], []
for b in batch:
_, mask, lines = self.model(b)
                if mask.shape[1] == 2:  # some OpenCV versions emit the two outputs in reversed order
                    mask, lines = lines, mask
mask_lst.append(mask)
line_lst.append(lines)
lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
else:
raise NotImplementedError
return lines, mask
@torch.no_grad()
async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
unclip_ratio: float, verbose: bool = False):
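        """Detect text lines in `image` and return (textlines, refined_mask, None).
        Note that detect_size, text_threshold, box_threshold and unclip_ratio are
        currently not used by this detector; a fixed box threshold of 0.6 is
        applied below."""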
# keep_undetected_mask = False
# refine_mode = REFINEMASK_INPAINT
im_h, im_w = image.shape[:2]
lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)
# blks = []
# resize_ratio = [1, 1]
if lines_map is None:
img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
blks, mask, lines_map = self.model(img_in)
if self.backend == 'opencv':
                if mask.shape[1] == 2:  # some OpenCV versions emit the two outputs in reversed order
                    mask, lines_map = lines_map, mask
mask = mask.squeeze()
# resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
# blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]
mask = postprocess_mask(mask)
lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
box_thresh = 0.6
idx = np.where(scores[0] > box_thresh)
lines, scores = lines[0][idx], scores[0][idx]
# map output to input img
mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
# if lines.size == 0:
# lines = []
# else:
# lines = lines.astype(np.int32)
        # YOLO was previously used to find the bounding boxes into which the text
        # lines were ordered. That job is now handled by the textline merger, which
        # seems to work more reliably. The YOLO language detection also seems
        # unnecessary, as it could never be as accurate as classifying the
        # OCR-extracted string directly. Using it to improve textline merge accuracy
        # doesn't really work either, since the merge can simply be postponed until
        # after the OCR finishes.
textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
mask_refined = refine_mask(image, mask, textlines, refine_mode=None)
return textlines, mask_refined, None
# blk_list = group_output(blks, lines, im_w, im_h, mask)
# mask_refined = refine_mask(image, mask, blk_list, refine_mode=refine_mode)
# if keep_undetected_mask:
# mask_refined = refine_undetected_mask(image, mask, mask_refined, blk_list, refine_mode=refine_mode)
# return blk_list, mask, mask_refined
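
# Minimal usage sketch (assumptions, not part of the module: the surrounding
# manga-image-translator package is importable so the relative imports resolve,
# OfflineDetector supports no-argument construction, and the model files are
# already present; the intended public entry point may differ from calling
# _load/_infer directly):
#
#     import asyncio
#     import cv2
#
#     async def demo():
#         detector = ComicTextDetector()
#         await detector._load('cpu')
#         image = cv2.imread('page.png')  # hypothetical input image
#         textlines, mask_refined, _ = await detector._infer(
#             image, detect_size=1024, text_threshold=0.5,
#             box_threshold=0.7, unclip_ratio=2.3, verbose=False)
#         print(f'{len(textlines)} text lines detected')
#
#     asyncio.run(demo())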