|
import os
import shutil
from typing import Union, Tuple

import cv2
import einops
import numpy as np
import torch

from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
from .ctd_utils.utils.yolov5_utils import non_max_suppression
from .ctd_utils.utils.db_utils import SegDetectorRepresenter
from .ctd_utils.utils.imgproc_utils import letterbox
from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
from .common import OfflineDetector
from ..utils import Quadrilateral, det_rearrange_forward
|
|
|
def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Letterbox to the model's input size; (dw, dh) is the padding that gets cropped off again downstream.
    img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
    img_in = img_in.transpose((2, 0, 1))[::-1]  # HWC -> CHW, reversed channel order
    img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
    if to_tensor:
        img_in = torch.from_numpy(img_in).to(device)
        if half:
            img_in = img_in.half()
    return img_in, ratio, int(dw), int(dh)
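
# Illustrative sketch, not part of the original module: how preprocess_img is
# typically driven. The dummy image, its size, and the helper name are
# assumptions for demonstration only.
def _example_preprocess_img():
    page = np.zeros((600, 800, 3), dtype=np.uint8)  # stand-in BGR page
    img_in, ratio, dw, dh = preprocess_img(page, input_size=(1024, 1024), to_tensor=False)
    # The page is letterboxed to 1024x1024 and normalized to [0, 1] as an
    # NCHW float32 array; dw/dh record the letterbox padding.
    assert img_in.shape == (1, 3, 1024, 1024)
    assert img_in.dtype == np.float32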
|
|
|
def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
    # Accept either a torch tensor (possibly on GPU) or a numpy array.
    if isinstance(img, torch.Tensor):
        img = img.squeeze_()
        if img.device != 'cpu':
            img = img.detach().cpu()
        img = img.numpy()
    else:
        img = img.squeeze()
    if thresh is not None:
        img = img > thresh
    img = img * 255
    return img.astype(np.uint8)
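
# Illustrative sketch (assumption, for demonstration only): thresholding a raw
# probability map into a binary uint8 mask with postprocess_mask.
def _example_postprocess_mask():
    prob = torch.rand(1, 1, 64, 64)           # stand-in sigmoid output
    mask = postprocess_mask(prob, thresh=0.3)
    # Pixels above the threshold become 255, the rest 0.
    assert mask.shape == (64, 64) and mask.dtype == np.uint8
    assert set(np.unique(mask)) <= {0, 255}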
|
|
|
def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
    det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
    # non_max_suppression returns tensors; always hand back numpy.
    det = det.detach().cpu().numpy()
    # Map letterboxed coordinates back to the source image resolution.
    det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
    det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
    if sort_func is not None:
        det = sort_func(det)
    blines = det[..., 0:4].astype(np.int32)  # box coordinates (x1, y1, x2, y2)
    confs = np.round(det[..., 4], 3)         # detection confidences
    cls = det[..., 5].astype(np.int32)       # class indices
    return blines, cls, confs
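
# Illustrative sketch (assumptions: the vendored non_max_suppression follows
# the upstream yolov5 convention of raw (cx, cy, w, h, objectness, class...)
# rows; all values here are made up for demonstration).
def _example_postprocess_yolo():
    det = torch.zeros((1, 1, 6))
    det[0, 0] = torch.tensor([64., 64., 32., 32., 0.9, 0.8])  # one confident box
    blines, cls, confs = postprocess_yolo(det, conf_thresh=0.4, nms_thresh=0.35,
                                          resize_ratio=(1.0, 1.0))
    # blines: int32 (x1, y1, x2, y2) boxes; cls: class ids; confs: rounded scores.
    return blines, cls, confs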
|
|
|
|
|
class ComicTextDetector(OfflineDetector):
    _MODEL_MAPPING = {
        'model-cuda': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
            'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
            'file': '.',
        },
        'model-cpu': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
            'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
            'file': '.',
        },
    }
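
    # Illustrative sketch (not part of the original class): the 'hash' fields
    # above are SHA-256 digests, which could be checked against a downloaded
    # file like this. The method name is hypothetical; the actual download and
    # verification presumably live in the OfflineDetector base class.
    @staticmethod
    def _example_verify_download(path: str, expected_sha256: str) -> bool:
        import hashlib
        sha = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                sha.update(chunk)
        return sha.hexdigest() == expected_sha256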
|
|
|
    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        # Migrate model files that an older version left in the working directory.
        if os.path.exists('comictextdetector.pt'):
            shutil.move('comictextdetector.pt', self._get_file_path('comictextdetector.pt'))
        if os.path.exists('comictextdetector.pt.onnx'):
            shutil.move('comictextdetector.pt.onnx', self._get_file_path('comictextdetector.pt.onnx'))
        super().__init__(*args, **kwargs)
|
|
|
    async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
        self.device = device
        if self.device in ('cuda', 'mps'):
            self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
            self.model.to(self.device)
            self.backend = 'torch'
        else:
            # Fall back to the ONNX model run through OpenCV's DNN module on CPU.
            model_path = self._get_file_path('comictextdetector.pt.onnx')
            self.model = TextDetBaseDNN(input_size, model_path)
            self.backend = 'opencv'

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size
        self.half = half
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.seg_rep = SegDetectorRepresenter(thresh=0.3)
|
|
|
    async def _unload(self):
        del self.model
|
|
|
    def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
        if isinstance(self.model, TextDetBase):
            batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
            batch = torch.from_numpy(batch).to(device)
            _, mask, lines = self.model(batch)
            mask = mask.detach().cpu().numpy()
            lines = lines.detach().cpu().numpy()
        elif isinstance(self.model, TextDetBaseDNN):
            mask_lst, line_lst = [], []
            for b in batch:
                _, mask, lines = self.model(b)
                # The DNN backend may return the outputs in swapped order; the
                # lines map has two channels, so swap if it arrived in the mask slot.
                if mask.shape[1] == 2:
                    mask, lines = lines, mask
                mask_lst.append(mask)
                line_lst.append(lines)
            lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
        else:
            raise NotImplementedError
        return lines, mask
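
    # Illustrative sketch (assumptions: _load has completed so self.model is
    # ready; the batch contents and the helper name are made up). Feeds a dummy
    # NHWC uint8 batch through whichever backend was loaded.
    def _example_batch_forward(self):
        batch = np.zeros((2, 1024, 1024, 3), dtype=np.uint8)  # stand-in pages
        lines, mask = self.det_batch_forward_ctd(batch, self.device)
        # Both outputs come back as numpy arrays batched along axis 0.
        return lines, mask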
|
|
|
    @torch.no_grad()
    async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                     unclip_ratio: float, verbose: bool = False):
        im_h, im_w = image.shape[:2]
        lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)

        # Fallback: when the rearranged forward pass is not applicable, run the
        # model once on a single letterboxed image.
        if lines_map is None:
            img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend == 'torch')
            blks, mask, lines_map = self.model(img_in)

            if self.backend == 'opencv':
                if mask.shape[1] == 2:  # the DNN outputs may come back in swapped order
                    mask, lines_map = lines_map, mask
            mask = mask.squeeze()
            # Crop away the letterbox padding on the height/width axes.
            mask = mask[..., :mask.shape[-2] - dh, :mask.shape[-1] - dw]
            lines_map = lines_map[..., :lines_map.shape[-2] - dh, :lines_map.shape[-1] - dw]

        mask = postprocess_mask(mask)
        lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
        # Note: a fixed box_thresh is applied here; the box_threshold argument is currently unused.
        box_thresh = 0.6
        idx = np.where(scores[0] > box_thresh)
        lines, scores = lines[0][idx], scores[0][idx]

        # Resize the mask back to the source image resolution.
        mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

        textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
        mask_refined = refine_mask(image, mask, textlines, refine_mode=None)

        return textlines, mask_refined, None
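
# Illustrative end-to-end sketch (assumptions: the model weights have already
# been downloaded, 'page.png' is a hypothetical input file, and the
# constructor needs no arguments beyond what OfflineDetector provides).
async def _example_detect_page(path: str = 'page.png'):
    detector = ComicTextDetector()
    await detector._load('cpu')  # ONNX/OpenCV backend on CPU
    image = cv2.imread(path)
    textlines, mask_refined, _ = await detector._infer(
        image, detect_size=1024, text_threshold=0.5,
        box_threshold=0.7, unclip_ratio=2.3)
    return textlines, mask_refined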