# Spaces: Runtime error
import time | |
from pathlib import Path | |
import cv2 | |
import uvicorn | |
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
from rapidocr_onnxruntime import RapidOCR | |
from fastapi import FastAPI, File, UploadFile, Query | |
from fastapi.responses import RedirectResponse | |
# ASGI application object; handed to ``uvicorn.run`` in the ``__main__`` guard.
app = FastAPI()
# Text-detection model file names offered to API clients. ``get_text``
# resolves the chosen name under ``models/text_det``.
det_models = [
    "ch_PP-OCRv4_det_infer.onnx",
    "ch_PP-OCRv3_det_infer.onnx",
    "ch_PP-OCRv2_det_infer.onnx",
    "ch_ppocr_server_v2.0_det_infer.onnx",
]
# Text-recognition model file names offered to API clients. ``get_text``
# resolves the chosen name under ``models/text_rec``.
rec_models = [
    "ch_PP-OCRv4_rec_infer.onnx",
    "ch_PP-OCRv3_rec_infer.onnx",
    "ch_PP-OCRv2_rec_infer.onnx",
    # NOTE(review): this entry is named "det" but sits in the recognition
    # list — confirm against the files actually shipped in models/text_rec.
    "ch_PP-OCRv4_det_server_infer.onnx",
    "ch_ppocr_server_v2.0_rec_infer.onnx",
    "en_PP-OCRv3_rec_infer.onnx",
    "en_number_mobile_v2.0_rec_infer.onnx",
    "korean_mobile_v2.0_rec_infer.onnx",
    "japan_rec_crnn_v2.onnx",
]
def get_text(
    image,
    text_det=None,
    text_rec=None,
    box_thresh=0.5,
    unclip_ratio=1.6,
    text_score=0.5,
):
    """Run RapidOCR on ``image`` and return the recognised text with scores.

    Args:
        image: Input forwarded unchanged to the ``RapidOCR`` engine call
            (the upload endpoint in this file passes the raw uploaded bytes).
        text_det: File name of the detection model, resolved under
            ``models/text_det``. When ``None``, no detection model path is
            passed to ``RapidOCR`` and the engine keeps its own default.
        text_rec: File name of the recognition model, resolved under
            ``models/text_rec``. When ``None``, neither a recognition model
            path nor a recognition image shape is passed to ``RapidOCR``.
        box_thresh: Forwarded to the engine call as ``box_thresh``.
        unclip_ratio: Forwarded to the engine call as ``unclip_ratio``.
        text_score: Forwarded to the engine call as ``text_score``.

    Returns:
        A ``pandas.DataFrame`` with columns ``Rec`` and ``Score`` (one row
        per OCR result), or ``None`` when the engine returns an empty result
        or empty timing information.
    """
    # Only pass the options that were actually chosen: joining ``None`` onto
    # a ``Path`` raises ``TypeError``, so the ``None`` defaults must not reach
    # the path construction.
    engine_kwargs = {}
    if text_det is not None:
        engine_kwargs["det_model_path"] = str(Path("models") / "text_det" / text_det)
    if text_rec is not None:
        rec_model_path = str(Path("models") / "text_rec" / text_rec)
        # Models whose path mentions "v2", "korean" or "japan" get the
        # [3, 32, 320] recognition shape; every other model gets [3, 48, 320].
        uses_small_shape = any(
            tag in rec_model_path for tag in ("v2", "korean", "japan")
        )
        engine_kwargs["rec_model_path"] = rec_model_path
        engine_kwargs["rec_img_shape"] = (
            [3, 32, 320] if uses_small_shape else [3, 48, 320]
        )

    rapid_ocr = RapidOCR(**engine_kwargs)
    ocr_result, infer_elapse = rapid_ocr(
        image, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score
    )
    if not ocr_result or not infer_elapse:
        return None

    # Each result row unpacks into three fields; only the last two (text and
    # score) are reported, the first one is discarded.
    _, rec_res, scores = zip(*ocr_result)
    return pd.DataFrame(
        [[rec, score] for rec, score in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )
def docs_redirect():
    """Return a redirect response pointing the client at ``/docs``."""
    # NOTE(review): no route decorator is visible on this handler — confirm
    # how it is registered with ``app``.
    target = "/docs"
    return RedirectResponse(url=target)
def create_upload_file(
    file: UploadFile,
    text_det: str = Query("ch_PP-OCRv4_det_infer.onnx", enum=det_models),
    text_rec: str = Query("en_number_mobile_v2.0_rec_infer.onnx", enum=rec_models),
    box_thresh: float = 0.5,
    unclip_ratio: float = 1.6,
    text_score: float = 0.5,
):
    """Run OCR on an uploaded file and return the result as column lists.

    Args:
        file: Uploaded file; its full content is read and handed to
            ``get_text``.
        text_det: Detection model file name; must be one of ``det_models``.
        text_rec: Recognition model file name; must be one of ``rec_models``.
        box_thresh: Forwarded to ``get_text``.
        unclip_ratio: Forwarded to ``get_text``.
        text_score: Forwarded to ``get_text``.

    Returns:
        A dict ``{"Rec": [...], "Score": [...]}``. Both lists are empty when
        ``get_text`` reports no result.

    Raises:
        HTTPException: Status 400 when ``text_det`` or ``text_rec`` is not
            one of the offered model names.
    """
    # NOTE(review): no route decorator is visible on this handler — confirm
    # how it is registered with ``app``.
    from fastapi import HTTPException

    # The model names come from the request and are joined into a filesystem
    # path by ``get_text``, so reject anything outside the offered lists.
    # NOTE(review): ``enum=`` on ``Query`` appears to only document the
    # choices in the schema rather than enforce them — confirm.
    if text_det not in det_models:
        raise HTTPException(status_code=400, detail="Unknown text_det model")
    if text_rec not in rec_models:
        raise HTTPException(status_code=400, detail="Unknown text_rec model")

    resp = get_text(
        file.file.read(), text_det, text_rec, box_thresh, unclip_ratio, text_score
    )
    # ``get_text`` returns None when nothing was recognised; answer with the
    # same column layout instead of failing on ``None.to_dict``.
    if resp is None:
        return {"Rec": [], "Score": []}
    return resp.to_dict("list")
if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    uvicorn.run(app, port=7860, host="0.0.0.0")