|
import numpy as np
|
|
from typing import List
|
|
|
|
from .common import CommonOCR, OfflineOCR
|
|
from .model_32px import Model32pxOCR
|
|
from .model_48px import Model48pxOCR
|
|
from .model_48px_ctc import Model48pxCTCOCR
|
|
from .model_manga_ocr import ModelMangaOCR
|
|
from ..utils import Quadrilateral
|
|
|
|
# Registry of the available OCR implementations: maps the key accepted by
# get_ocr() to the class that is instantiated for it.
OCRS = {
    '32px': Model32pxOCR,
    '48px': Model48pxOCR,
    '48px_ctc': Model48pxCTCOCR,
    'mocr': ModelMangaOCR,
}

# Instances already created by get_ocr(), stored under the same keys as OCRS
# so each implementation is constructed at most once and then reused.
ocr_cache = {}
|
|
|
|
def get_ocr(key: str, *args, **kwargs) -> CommonOCR:
    """Return the OCR instance registered under ``key``, creating it on first use.

    The instance is stored in ``ocr_cache`` and reused by later calls, so
    ``*args`` and ``**kwargs`` are only forwarded to the OCR class the first
    time a given key is requested; they are ignored once an instance is cached.

    Args:
        key: One of the keys of ``OCRS``.
        *args: Positional arguments for the OCR class constructor.
        **kwargs: Keyword arguments for the OCR class constructor.

    Returns:
        The cached (or newly created) OCR instance for ``key``.

    Raises:
        ValueError: If ``key`` is not a key of ``OCRS``.
    """
    if key not in OCRS:
        # Format the message with a single mechanism. Applying `%` to an
        # f-string that already contains the caller's key would interpret any
        # '%' inside the key as a conversion specifier and fail with a
        # TypeError instead of raising this ValueError.
        choices = ','.join(OCRS)
        raise ValueError(f'Could not find OCR for: "{key}". Choose from the following: {choices}')
    # Test for a missing entry explicitly rather than by truthiness, so an
    # instance that happens to evaluate falsy is not rebuilt on every call.
    if ocr_cache.get(key) is None:
        ocr_cache[key] = OCRS[key](*args, **kwargs)
    return ocr_cache[key]
|
|
|
|
async def prepare(ocr_key: str, device: str = 'cpu'):
    """Get the OCR selected by ``ocr_key`` ready for use on ``device``.

    Only ``OfflineOCR`` instances are prepared: their ``download()`` and then
    ``load(device)`` coroutines are awaited. For any other OCR returned by
    ``get_ocr`` this coroutine does nothing.

    Args:
        ocr_key: Key passed to ``get_ocr`` to select the OCR.
        device: Device identifier forwarded to ``load``.
    """
    engine = get_ocr(ocr_key)
    if not isinstance(engine, OfflineOCR):
        return
    await engine.download()
    await engine.load(device)
|
|
|
|
async def dispatch(ocr_key: str, image: np.ndarray, regions: List[Quadrilateral], args = None, device: str = 'cpu', verbose: bool = False) -> List[Quadrilateral]:
    """Run the OCR selected by ``ocr_key`` on ``regions`` of ``image``.

    An ``OfflineOCR`` instance has ``load(device)`` awaited before
    recognition. The call is then delegated to the OCR's ``recognize``
    coroutine and its result is returned unchanged.

    Args:
        ocr_key: Key passed to ``get_ocr`` to select the OCR.
        image: Image handed to ``recognize``.
        regions: Regions handed to ``recognize``.
        args: Extra options for ``recognize``; any falsy value (including the
            default ``None``) is replaced by an empty dict.
        device: Device identifier forwarded to ``load`` for offline OCRs.
        verbose: Verbosity flag handed to ``recognize``.

    Returns:
        Whatever the selected OCR's ``recognize`` returns.
    """
    engine = get_ocr(ocr_key)
    needs_loading = isinstance(engine, OfflineOCR)
    if needs_loading:
        await engine.load(device)
    options = args if args else {}
    return await engine.recognize(image, regions, options, verbose)
|
|
|