rapidocr-APId / app.py
valerii777's picture
Initial (no) commit (v0.0.1)
7f4886d verified
import time
from pathlib import Path
import cv2
import uvicorn
import numpy as np
import pandas as pd
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from fastapi import FastAPI, File, UploadFile, Query
from fastapi.responses import RedirectResponse
# FastAPI application instance; the route decorators in this file register
# their handlers on it and the __main__ guard hands it to uvicorn.
app = FastAPI()
# Detection model file names selectable through the /ocr endpoint's
# ``text_det`` query parameter; get_text() resolves them under
# models/text_det/.
det_models = [
    "ch_PP-OCRv4_det_infer.onnx",
    "ch_PP-OCRv3_det_infer.onnx",
    "ch_PP-OCRv2_det_infer.onnx",
    "ch_ppocr_server_v2.0_det_infer.onnx",
]
# Recognition model file names selectable through the /ocr endpoint's
# ``text_rec`` query parameter; get_text() resolves them under
# models/text_rec/.
rec_models = [
    "ch_PP-OCRv4_rec_infer.onnx",
    "ch_PP-OCRv3_rec_infer.onnx",
    "ch_PP-OCRv2_rec_infer.onnx",
    # NOTE(review): this name contains "det" although it sits in the
    # recognition list -- confirm against the files shipped in
    # models/text_rec/ before renaming.
    "ch_PP-OCRv4_det_server_infer.onnx",
    "ch_ppocr_server_v2.0_rec_infer.onnx",
    "en_PP-OCRv3_rec_infer.onnx",
    "en_number_mobile_v2.0_rec_infer.onnx",
    "korean_mobile_v2.0_rec_infer.onnx",
    "japan_rec_crnn_v2.onnx",
]
def get_text(
    image,
    text_det=None,
    text_rec=None,
    box_thresh=0.5,
    unclip_ratio=1.6,
    text_score=0.5,
):
    """Run RapidOCR on ``image`` and tabulate the result.

    Args:
        image: Input handed unchanged to the ``RapidOCR`` engine (the
            ``/ocr`` endpoint passes the raw bytes of the upload).
        text_det: File name of the detection model, resolved under
            ``models/text_det/``. Required despite the ``None`` default.
        text_rec: File name of the recognition model, resolved under
            ``models/text_rec/``. Required despite the ``None`` default.
        box_thresh: Forwarded unchanged to the engine call.
        unclip_ratio: Forwarded unchanged to the engine call.
        text_score: Forwarded unchanged to the engine call.

    Returns:
        A ``pandas.DataFrame`` with columns ``"Rec"`` and ``"Score"`` (the
        second and third element of every engine result entry --
        presumably the recognised text and its confidence), or ``None``
        when the engine reports no result or no timing information.

    Raises:
        ValueError: If ``text_det`` or ``text_rec`` is missing or empty.
    """
    # Fail early with a clear message instead of the opaque TypeError that
    # ``Path(...) / None`` would raise further down.
    if not text_det:
        raise ValueError("text_det must be the file name of a detection model")
    if not text_rec:
        raise ValueError("text_rec must be the file name of a recognition model")

    det_model_path = str(Path("models") / "text_det" / text_det)
    rec_model_path = str(Path("models") / "text_rec" / text_rec)

    # Models whose path mentions v2 / korean / japan get a 32-row input
    # shape, all others 48 (presumably channels, height, width -- confirm
    # against the RapidOCR documentation).
    uses_small_input = any(
        tag in rec_model_path for tag in ("v2", "korean", "japan")
    )
    rec_image_shape = [3, 32, 320] if uses_small_input else [3, 48, 320]

    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=rec_image_shape,
    )
    ocr_result, infer_elapse = rapid_ocr(
        image,
        box_thresh=box_thresh,
        unclip_ratio=unclip_ratio,
        text_score=text_score,
    )
    if not ocr_result or not infer_elapse:
        return None

    # Each result entry is a 3-item sequence; only the last two items are
    # reported, the first one (the box) is dropped.
    return pd.DataFrame(
        [[rec, score] for _, rec, score in ocr_result],
        columns=("Rec", "Score"),
    )
@app.get("/", include_in_schema=False)
def docs_redirect():
    """Redirect requests for the root URL to ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
@app.post("/ocr")
def create_upload_file(
    file: UploadFile,
    text_det: str = Query("ch_PP-OCRv4_det_infer.onnx", enum=det_models),
    text_rec: str = Query("en_number_mobile_v2.0_rec_infer.onnx", enum=rec_models),
    box_thresh: float = 0.5,
    unclip_ratio: float = 1.6,
    text_score: float = 0.5,
):
    """Run OCR on an uploaded file and return the result as column lists.

    Args:
        file: The uploaded file; its full content is read and passed to
            ``get_text``.
        text_det: Detection model file name, one of ``det_models``.
        text_rec: Recognition model file name, one of ``rec_models``.
        box_thresh: Forwarded unchanged to ``get_text``.
        unclip_ratio: Forwarded unchanged to ``get_text``.
        text_score: Forwarded unchanged to ``get_text``.

    Returns:
        A dict with the keys ``"Rec"`` and ``"Score"``, each mapping to a
        list with one item per OCR result row. Both lists are empty when
        ``get_text`` produced no result.
    """
    resp = get_text(
        file.file.read(), text_det, text_rec, box_thresh, unclip_ratio, text_score
    )
    # get_text returns None when nothing was recognised; calling .to_dict on
    # it would raise AttributeError and surface as an HTTP 500. Answer with
    # the same schema and empty columns instead.
    if resp is None:
        return {"Rec": [], "Score": []}
    return resp.to_dict("list")
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )