| import os | |
| import shutil | |
| import tempfile | |
| from time import perf_counter | |
| from typing import Any, List, Union | |
| from doctr import models as models | |
| from doctr.io import DocumentFile | |
| from doctr.models import ocr_predictor | |
| from PIL import Image | |
| from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest | |
| from inference.core.entities.requests.inference import InferenceRequest | |
| from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse | |
| from inference.core.entities.responses.inference import InferenceResponse | |
| from inference.core.env import MODEL_CACHE_DIR | |
| from inference.core.models.roboflow import RoboflowCoreModel | |
| from inference.core.utils.image_utils import load_image | |
class DocTR(RoboflowCoreModel):
    """End-to-end DocTR OCR model built from a text detector and a text recognizer.

    The detection weights come from ``DocTRDet`` and the recognition weights
    from ``DocTRRec``. Each weight file is copied to a doctr-style file name
    under a ``models/`` directory. A doctr ``ocr_predictor`` is then built with
    ``pretrained=True``.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR OCR model.

        Args:
            *args: Variable length argument list (not forwarded; see note below).
            model_id: Stored as ``self.endpoint``.
            **kwargs: Arbitrary keyword arguments; only ``api_key`` is read.
        """
        # NOTE(review): RoboflowCoreModel.__init__ is intentionally(?) not
        # called here; the attributes below are assigned by hand instead.
        # Confirm nothing else from the base initializer is required.
        self.api_key = kwargs.get("api_key")
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        # NOTE(review): DOCTR_CACHE_DIR points at ``doctr_rec`` while the
        # detection weights below are copied into ``doctr_det/models``. If
        # doctr resolves pretrained weights as ``$DOCTR_CACHE_DIR/models/<file>``,
        # the detector file will not be found there — confirm against the
        # doctr version in use.
        os.environ["DOCTR_CACHE_DIR"] = os.path.join(MODEL_CACHE_DIR, "doctr_rec")

        self.det_model = DocTRDet(api_key=kwargs.get("api_key"))
        self.rec_model = DocTRRec(api_key=kwargs.get("api_key"))

        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_rec/models/", exist_ok=True)
        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_det/models/", exist_ok=True)

        # Rename the cached ``model.pt`` files to the hash-suffixed names used
        # for doctr's pretrained weights (hashes are hard-coded here).
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_det/db_resnet50/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_det/models/db_resnet50-ac60cadc.pt",
        )
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_rec/crnn_vgg16_bn/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_rec/models/crnn_vgg16_bn-9762b0b0.pt",
        )

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    def clear_cache(self) -> None:
        """Clear the cached weights of both the detection and recognition models."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return ``image`` unchanged.

        DocTR pre-processes images as part of its inference pipeline, so no
        preprocessing is required here.
        """
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run OCR for a request and wrap the text plus elapsed time in a response.

        Args:
            request: The inference request; its fields are passed to ``infer``
                as keyword arguments.

        Returns:
            DoctrOCRInferenceResponse: The recognized text and the elapsed
            wall-clock time in seconds.
        """
        t1 = perf_counter()
        result = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=result,
            time=perf_counter() - t1,
        )

    def infer(self, image: Any, **kwargs) -> str:
        """Run OCR on a single image and return the recognized text.

        Args:
            image: Any image input accepted by ``load_image``.
            **kwargs: Extra request fields; ignored.

        Returns:
            str: All recognized words of the first page. Words within a line
            are joined by single spaces, and lines (across all blocks) are
            joined by single spaces.
        """
        img = load_image(image)
        # load_image returns a tuple whose first element is the image array.
        # NOTE(review): if that array is BGR, Image.fromarray will interpret
        # it as RGB and swap the channels — confirm the channel order.
        pil_image = Image.fromarray(img[0])

        # Use a file inside a private temporary directory instead of a
        # NamedTemporaryFile: the latter cannot be reopened by name while it
        # is still open on Windows, which is what save()/from_images() do.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "image.jpg")
            pil_image.save(image_path)
            doc = DocumentFile.from_images([image_path])
            exported = self.model(doc).export()

        blocks = exported["pages"][0]["blocks"]
        lines = [
            " ".join([word["value"] for word in line["words"]])
            for block in blocks
            for line in block["lines"]
        ]
        return " ".join(lines)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
class DocTRRec(RoboflowCoreModel):
    """Roboflow core model wrapper for the DocTR text-recognition weights."""

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR recognition model.

        Args:
            *args: Variable length argument list, forwarded to RoboflowCoreModel.
            model_id: Identifier of the recognition weights to load.
            **kwargs: Arbitrary keyword arguments, forwarded to RoboflowCoreModel.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
class DocTRDet(RoboflowCoreModel):
    """Roboflow core model wrapper for the DocTR text-detection weights."""

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Initializes the DocTR detection model.

        Args:
            *args: Variable length argument list, forwarded to RoboflowCoreModel.
            model_id: Identifier of the detection weights to load.
            **kwargs: Arbitrary keyword arguments, forwarded to RoboflowCoreModel.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]