# TrustOCR-Demo / app.py
# Last commit by MohammadReza-Halakoo: "Update app.py" (7fc4d76, verified)
# (Hugging Face file view: raw | history | blame | 8.14 kB)
import os
import tempfile
import argparse
import io
from typing import List
import pypdfium2
import streamlit as st
from surya.ocr import run_ocr, batch_text_detection # ✅ درست
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
import pytesseract
import cv2
import numpy as np
# -------------------
# Safe Streamlit runtime directory for Hugging Face
# -------------------
# Point STREAMLIT_RUNTIME_DIR at a ".streamlit" folder inside the system temp
# directory, creating the folder if it does not exist yet. Presumably the
# working directory is not writable on Hugging Face Spaces — TODO confirm.
# NOTE(review): streamlit is already imported above; verify it still reads
# this variable when it is set at this point.
runtime_dir = os.path.join(tempfile.gettempdir(), ".streamlit")
os.makedirs(runtime_dir, exist_ok=True)
os.environ.update({"STREAMLIT_RUNTIME_DIR": runtime_dir})
# -------------------
# Args
# -------------------
# NOTE(review): args.math is not read anywhere else in the visible code.
parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.")
parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
try:
    # parse_known_args collects unrecognised flags (e.g. ones forwarded by a
    # launcher) instead of failing on them. With parse_args an unknown flag
    # would reach the os._exit below and terminate the whole process — under
    # `streamlit run` presumably the server itself.
    args, _unparsed_args = parser.parse_known_args()
except SystemExit as e:
    # Still reached for --help (exit code 0) or a malformed --math value.
    # Only report real errors; --help has already printed its usage text.
    if e.code != 0:
        print(f"Error parsing arguments: {e}")
    # os._exit ends the process immediately, skipping cleanup handlers.
    os._exit(e.code)
# -------------------
# Helper Functions
# -------------------
def remove_border(image_path, output_path):
    """Crop an image to its largest four-cornered outline and save the result.

    The image is Otsu-binarised and its largest external contour is simplified
    with ``cv2.approxPolyDP``. If that yields exactly four corners, the
    quadrilateral is perspective-warped to an upright rectangle. Otherwise
    (no contours, not a quad, or a collapsed quad) the image is saved
    unchanged.

    Args:
        image_path: path of the image to read.
        output_path: path the (possibly cropped) image is written to with
            ``cv2.imwrite``; its extension selects the output format.

    Returns:
        The image that was written, as a BGR numpy array.

    Raises:
        PIL.UnidentifiedImageError: if neither OpenCV nor PIL can decode
            ``image_path``.
        OSError: if the result cannot be written to ``output_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising,
        # e.g. for GIF in many OpenCV builds (the uploader accepts GIF).
        # Fall back to PIL and convert its RGB pixels to OpenCV's BGR order.
        with Image.open(image_path) as pil_img:
            image = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result = image
    # max() over an empty sequence raises ValueError (e.g. a blank image), so
    # only search for an outline when at least one contour was found.
    if contours:
        max_contour = max(contours, key=cv2.contourArea)
        epsilon = 0.02 * cv2.arcLength(max_contour, True)
        approx = cv2.approxPolyDP(max_contour, epsilon, True)
        if len(approx) == 4:
            pts = approx.reshape(4, 2)
            # Order the corners tl, tr, br, bl. x+y is smallest at top-left
            # and largest at bottom-right. y-x is smallest at top-right and
            # largest at bottom-left.
            rect = np.zeros((4, 2), dtype="float32")
            s = pts.sum(axis=1)
            rect[0] = pts[np.argmin(s)]
            rect[2] = pts[np.argmax(s)]
            diff = np.diff(pts, axis=1)
            rect[1] = pts[np.argmin(diff)]
            rect[3] = pts[np.argmax(diff)]
            tl, tr, br, bl = rect
            max_width = max(int(np.linalg.norm(br - bl)), int(np.linalg.norm(tr - tl)))
            max_height = max(int(np.linalg.norm(tr - br)), int(np.linalg.norm(tl - bl)))
            # A collapsed quad (a side shorter than 2 px) has no usable warp
            # target; keep the original image instead of letting OpenCV fail.
            if max_width > 1 and max_height > 1:
                dst = np.array([[0, 0], [max_width - 1, 0],
                                [max_width - 1, max_height - 1],
                                [0, max_height - 1]], dtype="float32")
                M = cv2.getPerspectiveTransform(rect, dst)
                result = cv2.warpPerspective(image, M, (max_width, max_height))

    # cv2.imwrite reports failure only through its return value.
    if not cv2.imwrite(output_path, result):
        raise OSError(f"Could not write image to {output_path}")
    return result
def text_detection(img):
    """Detect text lines in a single image.

    Uses the module-level ``det_model`` / ``det_processor``.

    Returns:
        A tuple of (copy of ``img`` with the detected polygons drawn on it,
        the detection result for the image).
    """
    result = batch_text_detection([img], det_model, det_processor)[0]
    overlay = draw_polys_on_image([box.polygon for box in result.bboxes], img.copy())
    return overlay, result
def layout_detection(img):
    """Detect page-layout regions in a single image.

    Text detection runs first because its result is passed to
    ``batch_layout_detection`` together with the module-level layout model.

    Returns:
        A tuple of (copy of ``img`` with the labelled layout polygons drawn on
        it, the layout result for the image).
    """
    det_result = text_detection(img)[1]
    layout_result = batch_layout_detection([img], layout_model, layout_processor, [det_result])[0]
    regions = layout_result.bboxes
    overlay = draw_polys_on_image(
        [region.polygon for region in regions],
        img.copy(),
        labels=[region.label for region in regions],
        label_font_size=40,
    )
    return overlay, layout_result
def order_detection(img):
    """Predict the reading order of the layout regions in a single image.

    Runs ``layout_detection`` (which itself runs text detection), then
    ``batch_ordering`` on the layout boxes with the module-level order model.

    Returns:
        A tuple of (copy of ``img`` with each region labelled by its reading
        position, the ordering result for the image).
    """
    layout_result = layout_detection(img)[1]
    region_boxes = [region.bbox for region in layout_result.bboxes]
    order_result = batch_ordering([img], [region_boxes], order_model, order_processor)[0]
    entries = order_result.bboxes
    overlay = draw_polys_on_image(
        [entry.polygon for entry in entries],
        img.copy(),
        labels=[str(entry.position) for entry in entries],
        label_font_size=40,
    )
    return overlay, order_result
def ocr(img, langs: List[str]):
    """Run text detection + recognition on a single image.

    Args:
        img: the page image; ``img.size`` is passed on as the canvas size of
            the rendered-text image (presumably a PIL image — TODO confirm).
        langs: language names and/or codes to recognise. The list is NOT
            modified; a converted copy is used internally.

    Returns:
        A tuple of (image with the recognised text drawn at each line's bbox,
        the OCR result for the image).
    """
    # replace_lang_with_code's return value is unused by the original call, so
    # it rewrites its argument in place. Convert a copy so the caller's list
    # (e.g. the sidebar selection) is left untouched.
    lang_codes = list(langs)
    replace_lang_with_code(lang_codes)
    img_pred = run_ocr([img], [lang_codes], det_model, det_processor, rec_model, rec_processor)[0]
    bboxes = [line.bbox for line in img_pred.text_lines]
    text = [line.text for line in img_pred.text_lines]
    # "_math" in the converted code list switches on math rendering.
    rec_img = draw_text_on_image(bboxes, text, img.size, lang_codes, has_math="_math" in lang_codes)
    return rec_img, img_pred
def open_pdf(pdf_file):
    """Open an uploaded PDF from its in-memory bytes.

    Only ``pdf_file.getvalue()`` is used; the bytes are wrapped in a fresh
    ``io.BytesIO`` so the upload's own read position is not involved.
    """
    return pypdfium2.PdfDocument(io.BytesIO(pdf_file.getvalue()))
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render one page of an uploaded PDF as an RGB PIL image.

    Results are memoised by ``st.cache_data``.

    Args:
        pdf_file: uploaded PDF; only its ``getvalue()`` bytes are read.
        page_num: 1-based page number.
        dpi: render resolution. PDF user space is 72 units per inch, hence
            ``scale=dpi / 72``.

    Returns:
        The rendered page converted to RGB.
    """
    doc = open_pdf(pdf_file)
    try:
        # pypdfium2 page indices are 0-based.
        renderer = doc.render(pypdfium2.PdfBitmap.to_pil, page_indices=[page_num - 1], scale=dpi / 72)
        # Exhaust the renderer (a single page) so rendering has finished
        # before the document is closed below.
        png = list(renderer)[0]
        return png.convert("RGB")
    finally:
        # Release the native PDFium document now rather than at GC time.
        doc.close()
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in an uploaded PDF (memoised by st.cache_data).

    Args:
        pdf_file: uploaded PDF; only its ``getvalue()`` bytes are read.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the native PDFium document now rather than at GC time.
        doc.close()
# -------------------
# Streamlit UI
# -------------------
# Wide layout so the input image and the results can sit side by side.
st.set_page_config(layout="wide")
# Two equal-width columns. The first (col2) later shows the input image and
# the second (col1) shows whatever result a sidebar button produces.
col2, col1 = st.columns(spec=[.5, .5])
@st.cache_resource()
def load_det_cached():
    """Load the surya text-detection model and processor.

    Wrapped in ``st.cache_resource`` so the objects are reused across reruns.
    """
    checkpoint = "vikp/surya_det2"
    model = load_model(checkpoint=checkpoint)
    processor = load_processor(checkpoint=checkpoint)
    return model, processor
@st.cache_resource()
def load_rec_cached():
    """Load the TrustOCR text-recognition model and processor.

    Wrapped in ``st.cache_resource`` so the objects are reused across reruns.
    """
    checkpoint = "MohammadReza-Halakoo/TrustOCR"
    model = load_rec_model(checkpoint=checkpoint)
    processor = load_rec_processor(checkpoint=checkpoint)
    return model, processor
@st.cache_resource()
def load_layout_cached():
    """Load the surya layout model and processor.

    Uses the same segformer loaders as detection, with the layout checkpoint.
    Wrapped in ``st.cache_resource`` so the objects are reused across reruns.
    """
    checkpoint = "vikp/surya_layout2"
    model = load_model(checkpoint=checkpoint)
    processor = load_processor(checkpoint=checkpoint)
    return model, processor
@st.cache_resource()
def load_order_cached():
    """Load the surya reading-order model and processor.

    Wrapped in ``st.cache_resource`` so the objects are reused across reruns.
    """
    checkpoint = "vikp/surya_order"
    model = load_order_model(checkpoint=checkpoint)
    processor = load_order_processor(checkpoint=checkpoint)
    return model, processor
# Module-level models used by the detection / OCR helpers.
det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
st.markdown("# TRUST OCR DEMO")

# Sidebar inputs (Persian labels: "PDF file or image:", "Languages").
in_file = st.sidebar.file_uploader("فایل PDF یا عکس :", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
languages = st.sidebar.multiselect("زبان‌ها", sorted(CODE_TO_LANGUAGE.values()), default=["Persian"], max_selections=4)
if in_file is None:
    # Nothing uploaded yet: end this script run here.
    st.stop()

filetype = in_file.type
if "pdf" in filetype:
    # Label: "Page:". The selectable range is bounded by the PDF's page count.
    page_number = st.sidebar.number_input("صفحه:", min_value=1, value=1, max_value=page_count(in_file))
    pil_image = get_page_image(in_file, page_number)
else:
    bytes_data = in_file.getvalue()
    temp_dir = "temp_files"
    os.makedirs(temp_dir, exist_ok=True)
    # in_file.name is supplied by the client; keep only its last path
    # component so a crafted name such as "../../x.png" cannot escape temp_dir.
    file_path = os.path.join(temp_dir, os.path.basename(in_file.name))
    with open(file_path, "wb") as f:
        f.write(bytes_data)
    # splitext removes only the extension (str.split('.') would cut at the
    # first dot, mapping "scan.v1.png" and "scan.v2.png" to the same file).
    out_file = os.path.splitext(file_path)[0] + "-1.JPG"
    remove_border(file_path, out_file)
    # The context manager closes the file handle Image.open keeps open.
    with Image.open(out_file) as cropped:
        pil_image = cropped.convert("RGB")

# Sidebar actions (Persian labels): "Text detection", "Convert to text",
# "Page analysis", "Reading order".
text_det = st.sidebar.button("تشخیص متن")
text_rec = st.sidebar.button("تبدیل به متن")
layout_det = st.sidebar.button("آنالیز صفحه")
order_det = st.sidebar.button("ترتیب خوانش")

if text_det:
    # Tesseract OSD estimates the page orientation. It can raise
    # TesseractError when it cannot decide (e.g. too little text); use the
    # image as uploaded in that case instead of crashing the app.
    try:
        osd = pytesseract.image_to_osd(pil_image, output_type='dict')
        angle = osd['orientation']
    except pytesseract.TesseractError:
        angle = 0
    # NOTE(review): PIL rotates counter-clockwise; confirm that 'orientation'
    # (rather than OSD's 'rotate' field) is the intended correction angle.
    # expand=True enlarges the canvas so 90/270-degree turns of non-square
    # pages are not cropped to the original width and height.
    im_fixed = pil_image.rotate(angle, expand=True)
    det_img, pred = text_detection(im_fixed)
    with col1:
        st.image(det_img, caption="تشخیص متن", use_column_width=True)
if layout_det:
    layout_img, pred = layout_detection(pil_image)
    with col1:
        st.image(layout_img, caption="آنالیز صفحه", use_column_width=True)
if text_rec:
    rec_img, pred = ocr(pil_image, languages)
    with col1:
        # Tabs: "Page text" (one recognised line per row) and the raw result.
        text_tab, json_tab = st.tabs(["متن صفحه", "JSON"])
        with text_tab:
            st.text("\n".join([p.text for p in pred.text_lines]))
        with json_tab:
            st.json(pred.model_dump(), expanded=True)
if order_det:
    order_img, pred = order_detection(pil_image)
    with col1:
        st.image(order_img, caption="ترتیب خوانش", use_column_width=True)
# Caption: "Input image".
with col2:
    st.image(pil_image, caption="تصویر ورودی", use_column_width=True)