"""FastAPI service exposing RapidOCR text detection + recognition over HTTP."""
import time
from pathlib import Path

import cv2
import uvicorn
import numpy as np
import pandas as pd
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from fastapi import FastAPI, File, HTTPException, UploadFile, Query
from fastapi.responses import RedirectResponse

app = FastAPI()

# Model files are looked up under models/text_det and models/text_rec.
MODELS_DIR = Path("models")

# Defaults shared by the HTTP endpoint and by direct callers of get_text().
DEFAULT_DET_MODEL = "ch_PP-OCRv4_det_infer.onnx"
DEFAULT_REC_MODEL = "en_number_mobile_v2.0_rec_infer.onnx"

det_models = [
    "ch_PP-OCRv4_det_infer.onnx",
    "ch_PP-OCRv3_det_infer.onnx",
    "ch_PP-OCRv2_det_infer.onnx",
    "ch_ppocr_server_v2.0_det_infer.onnx",
]
rec_models = [
    "ch_PP-OCRv4_rec_infer.onnx",
    "ch_PP-OCRv3_rec_infer.onnx",
    "ch_PP-OCRv2_rec_infer.onnx",
    # NOTE(review): the name suggests a *detection* model listed among the
    # recognition models; confirm this file exists in models/text_rec and is
    # really a recognition model (possibly meant ch_PP-OCRv4_rec_server_infer).
    "ch_PP-OCRv4_det_server_infer.onnx",
    "ch_ppocr_server_v2.0_rec_infer.onnx",
    "en_PP-OCRv3_rec_infer.onnx",
    "en_number_mobile_v2.0_rec_infer.onnx",
    "korean_mobile_v2.0_rec_infer.onnx",
    "japan_rec_crnn_v2.onnx",
]

# Recognition models whose file path contains one of these substrings are run
# with an input height of 32; all others use 48.
_REC_SHAPE_32_MARKERS = ("v2", "korean", "japan")


def get_text(
    image,
    text_det=None,
    text_rec=None,
    box_thresh=0.5,
    unclip_ratio=1.6,
    text_score=0.5,
):
    """Run RapidOCR on *image* and return the recognised lines.

    Args:
        image: Image passed straight to ``RapidOCR.__call__`` (the endpoint
            supplies the raw uploaded bytes).
        text_det: File name of the detection model under ``models/text_det``.
            ``None`` selects ``DEFAULT_DET_MODEL``.
        text_rec: File name of the recognition model under ``models/text_rec``.
            ``None`` selects ``DEFAULT_REC_MODEL``.
        box_thresh, unclip_ratio, text_score: Forwarded to RapidOCR.

    Returns:
        A DataFrame with columns ``"Rec"`` (text) and ``"Score"``, one row per
        recognised line, or ``None`` when RapidOCR returns no result.
    """
    if text_det is None:
        text_det = DEFAULT_DET_MODEL
    if text_rec is None:
        text_rec = DEFAULT_REC_MODEL

    det_model_path = str(MODELS_DIR / "text_det" / text_det)
    rec_model_path = str(MODELS_DIR / "text_rec" / text_rec)

    if any(marker in rec_model_path for marker in _REC_SHAPE_32_MARKERS):
        rec_image_shape = [3, 32, 320]
    else:
        rec_image_shape = [3, 48, 320]

    # NOTE(review): a new RapidOCR (and ONNX sessions) is built per call; this
    # is slow, but caching it would need care because thresholds are passed
    # per call.
    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=rec_image_shape,
    )
    ocr_result, infer_elapse = rapid_ocr(
        image, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score
    )
    if not ocr_result or not infer_elapse:
        return None

    # Each result entry is (box, text, score); the box is not returned.
    _, rec_res, scores = zip(*ocr_result)
    return pd.DataFrame(
        [[rec, score] for rec, score in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )


@app.get("/", include_in_schema=False)
def docs_redirect():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url='/docs')


@app.post("/ocr")
def create_upload_file(
    file: UploadFile,
    text_det: str = Query(DEFAULT_DET_MODEL, enum=det_models),
    text_rec: str = Query(DEFAULT_REC_MODEL, enum=rec_models),
    box_thresh: float = 0.5,
    unclip_ratio: float = 1.6,
    text_score: float = 0.5,
):
    """OCR an uploaded image and return ``{"Rec": [...], "Score": [...]}``.

    Raises:
        HTTPException: 422 for a model name outside the allowed lists,
            400 for an empty upload.
    """
    # `enum=` on Query only documents the choices in the OpenAPI schema; it is
    # not enforced. These names are joined into a filesystem path, so validate
    # them explicitly to stop clients from loading arbitrary files.
    if text_det not in det_models:
        raise HTTPException(
            status_code=422, detail=f"Unknown text_det model: {text_det!r}"
        )
    if text_rec not in rec_models:
        raise HTTPException(
            status_code=422, detail=f"Unknown text_rec model: {text_rec!r}"
        )

    image_bytes = file.file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    resp = get_text(
        image_bytes, text_det, text_rec, box_thresh, unclip_ratio, text_score
    )
    if resp is None:
        # No text found: answer with empty columns rather than failing.
        return {"Rec": [], "Score": []}
    return resp.to_dict("list")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)