""" Copyright (c) 2019-present NAVER Corp. MIT License """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import os import shutil import numpy as np import torch import cv2 import einops from typing import List, Tuple from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault from .default_utils import imgproc, dbnet_utils, craft_utils from .common import OfflineDetector from ..utils import TextBlock, Quadrilateral, det_rearrange_forward from shapely.geometry import Polygon, MultiPoint from shapely import affinity from .craft_utils.vgg16_bn import vgg16_bn, init_weights from .craft_utils.refiner import RefineNet class double_conv(nn.Module): def __init__(self, in_ch, mid_ch, out_ch): super(double_conv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), nn.BatchNorm2d(mid_ch), nn.ReLU(inplace=True), nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): x = self.conv(x) return x class CRAFT(nn.Module): def __init__(self, pretrained=False, freeze=False): super(CRAFT, self).__init__() """ Base network """ self.basenet = vgg16_bn(pretrained, freeze) """ U network """ self.upconv1 = double_conv(1024, 512, 256) self.upconv2 = double_conv(512, 256, 128) self.upconv3 = double_conv(256, 128, 64) self.upconv4 = double_conv(128, 64, 32) num_class = 2 self.conv_cls = nn.Sequential( nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(16, num_class, kernel_size=1), ) init_weights(self.upconv1.modules()) init_weights(self.upconv2.modules()) init_weights(self.upconv3.modules()) init_weights(self.upconv4.modules()) init_weights(self.conv_cls.modules()) def forward(self, x): """ Base network """ sources = self.basenet(x) """ U network """ y = torch.cat([sources[0], sources[1]], dim=1) y = self.upconv1(y) y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) y = torch.cat([y, sources[2]], dim=1) y = self.upconv2(y) y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) y = torch.cat([y, sources[3]], dim=1) y = self.upconv3(y) y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) y = torch.cat([y, sources[4]], dim=1) feature = self.upconv4(y) y = self.conv_cls(feature) return y.permute(0,2,3,1), feature from collections import OrderedDict def copyStateDict(state_dict): if list(state_dict.keys())[0].startswith("module"): start_idx = 1 else: start_idx = 0 new_state_dict = OrderedDict() for k, v in state_dict.items(): name = ".".join(k.split(".")[start_idx:]) new_state_dict[name] = v return new_state_dict class CRAFTDetector(OfflineDetector): _MODEL_MAPPING = { 'refiner': { 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_refiner_CTW1500.pth', 'hash': 'f7000cd3e9c76f2231b62b32182212203f73c08dfaa12bb16ffb529948a01399', 'file': 'craft_refiner_CTW1500.pth', }, 'craft': { 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_mlt_25k.pth', 'hash': '4a5efbfb48b4081100544e75e1e2b57f8de3d84f213004b14b85fd4b3748db17', 'file': 'craft_mlt_25k.pth', } } def __init__(self, *args, **kwargs): os.makedirs(self.model_dir, exist_ok=True) if 
os.path.exists('craft_mlt_25k.pth'): shutil.move('craft_mlt_25k.pth', self._get_file_path('craft_mlt_25k.pth')) if os.path.exists('craft_refiner_CTW1500.pth'): shutil.move('craft_refiner_CTW1500.pth', self._get_file_path('craft_refiner_CTW1500.pth')) super().__init__(*args, **kwargs) async def _load(self, device: str): self.model = CRAFT() self.model.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_mlt_25k.pth'), map_location='cpu'))) self.model.eval() self.model_refiner = RefineNet() self.model_refiner.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_refiner_CTW1500.pth'), map_location='cpu'))) self.model_refiner.eval() self.device = device if device == 'cuda' or device == 'mps': self.model = self.model.to(self.device) self.model_refiner = self.model_refiner.to(self.device) global MODEL MODEL = self.model async def _unload(self): del self.model async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float, unclip_ratio: float, verbose: bool = False): img_resized, target_ratio, size_heatmap, pad_w, pad_h = imgproc.resize_aspect_ratio(image, detect_size, interpolation = cv2.INTER_CUBIC, mag_ratio = 1) ratio_h = ratio_w = 1 / target_ratio # preprocessing x = imgproc.normalizeMeanVariance(img_resized) x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] x = x.unsqueeze(0).to(self.device) # [c, h, w] to [b, c, h, w] with torch.no_grad() : y, feature = self.model(x) # make score and link map score_text = y[0,:,:,0].cpu().data.numpy() score_link = y[0,:,:,1].cpu().data.numpy() # refine link y_refiner = self.model_refiner(y, feature) score_link = y_refiner[0,:,:,0].cpu().data.numpy() # Post-processing boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, box_threshold, box_threshold, True) # coordinate adjustment boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) for k in range(len(polys)): if polys[k] is None: polys[k] = boxes[k] mask = np.zeros(shape = (image.shape[0], image.shape[1]), dtype = np.uint8) for poly in polys : mask = cv2.fillPoly(mask, [poly.reshape((-1, 1, 2)).astype(np.int32)], color = 255) polys_ret = [] for i in range(len(polys)) : poly = MultiPoint(polys[i]) if poly.area > 10 : rect = poly.minimum_rotated_rectangle rect = affinity.scale(rect, xfact = 1.2, yfact = 1.2) polys_ret.append(np.roll(np.asarray(list(rect.exterior.coords)[:4]), 2)) kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9)) mask = cv2.dilate(mask, kern) textlines = [Quadrilateral(pts.astype(int), '', 1) for pts in polys_ret] textlines = list(filter(lambda q: q.area > 16, textlines)) return textlines, mask, None
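

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, hedged example of
# driving CRAFTDetector end to end. It assumes the OfflineDetector base class
# can be constructed without extra arguments and that the two checkpoints from
# _MODEL_MAPPING are already present under model_dir. 'sample.jpg' and the
# numeric values below are illustrative placeholders, not values mandated by
# this module.
if __name__ == '__main__':
    import asyncio

    async def main():
        detector = CRAFTDetector()
        await detector._load('cuda' if torch.cuda.is_available() else 'cpu')

        # _infer expects an RGB array; cv2.imread returns BGR, so convert.
        image = cv2.cvtColor(cv2.imread('sample.jpg'), cv2.COLOR_BGR2RGB)

        textlines, mask, _ = await detector._infer(
            image,
            detect_size=1536,     # longer-side target for resize_aspect_ratio
            text_threshold=0.7,   # region score cutoff
            box_threshold=0.4,    # reused for both link and low-text cutoffs
            unclip_ratio=2.3,     # accepted by the signature but unused in _infer
        )
        print(f'{len(textlines)} text lines detected')
        cv2.imwrite('text_mask.png', mask)

    asyncio.run(main())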