# app.py — TRUST OCR DEMO (Streamlit) — works even if batch_text_detection is missing
import os
import io
import tempfile
from typing import List
import numpy as np
import cv2
from PIL import Image
import pypdfium2
import pytesseract
import streamlit as st
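# NOTE: assumed runtime dependencies (not pinned in this file): streamlit, surya-ocr,
# pypdfium2, pytesseract, opencv-python, numpy, pillow, plus the system `tesseract`
# binary that pytesseract shells out to for OSD-based rotation.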
# ===== Safe runtime dir for Streamlit/HF cache =====
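# Hosted runtimes (e.g. Hugging Face Spaces) often have read-only home directories,
# so Streamlit's runtime files are kept under the system temp dir.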
runtime_dir = os.path.join(tempfile.gettempdir(), ".streamlit")
os.environ["STREAMLIT_RUNTIME_DIR"] = runtime_dir
os.makedirs(runtime_dir, exist_ok=True)
# ===== Try to import Surya APIs =====
DET_AVAILABLE = True
try:
from surya.detection import batch_text_detection
except Exception:
DET_AVAILABLE = False
try:
    from surya.layout import batch_layout_detection  # gated at call time by DET_AVAILABLE
except Exception:
    batch_layout_detection = None
    DET_AVAILABLE = False
# Detection model loaders: the module path differs between Surya versions (segformer vs. model)
try:
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
except Exception:
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.ocr import run_ocr
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.text import draw_text_on_image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
# ===================== Helper Functions =====================
def remove_border(image_path: str, output_path: str) -> np.ndarray:
"""Remove outer border & deskew (perspective) if a rectangular contour is found."""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Cannot read image: {image_path}")
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
cv2.imwrite(output_path, image)
return image
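    # Assume the largest external contour is the document/page border.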
max_contour = max(contours, key=cv2.contourArea)
epsilon = 0.02 * cv2.arcLength(max_contour, True)
approx = cv2.approxPolyDP(max_contour, epsilon, True)
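    # Treat a 4-corner approximation as the page outline and deskew it with a perspective warp.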
if len(approx) == 4:
pts = approx.reshape(4, 2).astype("float32")
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)] # tl
rect[2] = pts[np.argmax(s)] # br
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)] # tr
rect[3] = pts[np.argmax(diff)] # bl
(tl, tr, br, bl) = rect
widthA = np.linalg.norm(br - bl)
widthB = np.linalg.norm(tr - tl)
maxWidth = max(int(widthA), int(widthB))
heightA = np.linalg.norm(tr - br)
heightB = np.linalg.norm(tl - bl)
maxHeight = max(int(heightA), int(heightB))
dst = np.array([[0, 0], [maxWidth - 1, 0],
[maxWidth - 1, maxHeight - 1],
[0, maxHeight - 1]], dtype="float32")
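        # Warp the ordered source corners onto the axis-aligned destination rectangle.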
M = cv2.getPerspectiveTransform(rect, dst)
cropped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
cv2.imwrite(output_path, cropped)
return cropped
else:
cv2.imwrite(output_path, image)
return image
def open_pdf(pdf_file) -> pypdfium2.PdfDocument:
stream = io.BytesIO(pdf_file.getvalue())
return pypdfium2.PdfDocument(stream)
@st.cache_data(show_spinner=False)
def get_page_image(pdf_file, page_num: int, dpi: int = 96) -> Image.Image:
doc = open_pdf(pdf_file)
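    # pypdfium2 renders at 72 dpi when scale=1, so scale = dpi / 72 gives the requested resolution.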
renderer = doc.render(pypdfium2.PdfBitmap.to_pil, page_indices=[page_num - 1], scale=dpi / 72)
png = list(renderer)[0]
return png.convert("RGB")
@st.cache_data(show_spinner=False)
def page_count(pdf_file) -> int:
doc = open_pdf(pdf_file)
return len(doc)
# ===================== Streamlit UI =====================
st.set_page_config(page_title="TRUST OCR DEMO", layout="wide")
st.markdown("# TRUST OCR DEMO")
if not DET_AVAILABLE:
st.warning("⚠️ ماژول تشخیص متن Surya در این محیط در دسترس نیست. OCR کامل کار می‌کند، اما دکمه‌های Detection/Layout/Order غیرفعال شده‌اند. برای فعال‌سازی آن‌ها، Surya را به نسخهٔ سازگار پین کنید (راهنما پایین صفحه).")
# Sidebar controls
in_file = st.sidebar.file_uploader("فایل PDF یا عکس :", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
languages = st.sidebar.multiselect(
"زبان‌ها (Languages)",
sorted(list(CODE_TO_LANGUAGE.values())),
default=["Persian"],
max_selections=4
)
auto_rotate = st.sidebar.toggle("چرخش خودکار (Tesseract OSD)", value=True)
auto_border = st.sidebar.toggle("حذف قاب/کادر تصویر ورودی", value=True)
# Buttons (disable some if detection missing)
text_det_btn = st.sidebar.button("تشخیص متن (Detection)", disabled=not DET_AVAILABLE)
layout_det_btn = st.sidebar.button("آنالیز صفحه (Layout)", disabled=not DET_AVAILABLE)
order_det_btn = st.sidebar.button("ترتیب خوانش (Reading Order)", disabled=not DET_AVAILABLE)
text_rec_btn = st.sidebar.button("تبدیل به متن (Recognition)")
if in_file is None:
st.info("یک فایل PDF/عکس از سایدبار انتخاب کنید. | Please upload a file to begin.")
st.stop()
filetype = in_file.type
# Two-column layout (left: input preview / right: outputs)
col2, col1 = st.columns([.5, .5])
# ===================== Load Models (cached) =====================
@st.cache_resource(show_spinner=True)
def load_det_cached():
return load_det_model(checkpoint="vikp/surya_det2"), load_det_processor(checkpoint="vikp/surya_det2")
@st.cache_resource(show_spinner=True)
def load_rec_cached():
return load_rec_model(checkpoint="MohammadReza-Halakoo/TrustOCR"), \
load_rec_processor(checkpoint="MohammadReza-Halakoo/TrustOCR")
@st.cache_resource(show_spinner=True)
def load_layout_cached():
return load_det_model(checkpoint="vikp/surya_layout2"), load_det_processor(checkpoint="vikp/surya_layout2")
@st.cache_resource(show_spinner=True)
def load_order_cached():
return load_order_model(checkpoint="vikp/surya_order"), load_order_processor(checkpoint="vikp/surya_order")
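# st.cache_resource keeps each loaded model/processor pair in memory across Streamlit
# reruns, so checkpoints are downloaded and initialized only once per process.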
# Recognition models are always loaded; detection/layout/order models are loaded only when DET_AVAILABLE
rec_model, rec_processor = load_rec_cached()
if DET_AVAILABLE:
det_model, det_processor = load_det_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
else:
det_model = det_processor = layout_model = layout_processor = order_model = order_processor = None
# ===================== High-level Ops =====================
def _apply_auto_rotate(pil_img: Image.Image) -> Image.Image:
"""Auto-rotate using Tesseract OSD if enabled."""
if not auto_rotate:
return pil_img
try:
osd = pytesseract.image_to_osd(pil_img, output_type=pytesseract.Output.DICT)
angle = int(osd.get("rotate", 0)) # 0/90/180/270
if angle and angle % 360 != 0:
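            # PIL's rotate() is counter-clockwise for positive angles, so negate the OSD
            # angle to rotate the page back; expand=True keeps the full rotated frame.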
return pil_img.rotate(-angle, expand=True)
return pil_img
except Exception as e:
st.warning(f"OSD rotation failed, continuing without rotation. Error: {e}")
return pil_img
def text_detection(pil_img: Image.Image):
pred: TextDetectionResult = batch_text_detection([pil_img], det_model, det_processor)[0]
polygons = [p.polygon for p in pred.bboxes]
det_img = draw_polys_on_image(polygons, pil_img.copy())
return det_img, pred
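# Layout analysis consumes a text-detection result for the same page, so it re-runs detection first.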
def layout_detection(pil_img: Image.Image):
_, det_pred = text_detection(pil_img)
pred: LayoutResult = batch_layout_detection([pil_img], layout_model, layout_processor, [det_pred])[0]
polygons = [p.polygon for p in pred.bboxes]
labels = [p.label for p in pred.bboxes]
layout_img = draw_polys_on_image(polygons, pil_img.copy(), labels=labels, label_font_size=40)
return layout_img, pred
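# Reading order is computed over the layout regions, so this chains detection -> layout -> ordering.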
def order_detection(pil_img: Image.Image):
_, layout_pred = layout_detection(pil_img)
bboxes = [l.bbox for l in layout_pred.bboxes]
pred: OrderResult = batch_ordering([pil_img], [bboxes], order_model, order_processor)[0]
polys = [l.polygon for l in pred.bboxes]
positions = [str(l.position) for l in pred.bboxes]
order_img = draw_polys_on_image(polys, pil_img.copy(), labels=positions, label_font_size=40)
return order_img, pred
def ocr_page(pil_img: Image.Image, langs: List[str]):
"""Full-page OCR using Surya run_ocr — works without detection import."""
langs = list(langs) if langs else ["Persian"]
replace_lang_with_code(langs) # in-place
    # If the detection models were loaded, pass them explicitly; otherwise fall back to
    # calling run_ocr without them (this assumes a Surya build whose run_ocr can supply
    # its own detection defaults).
    if det_model is not None and det_processor is not None:
        img_pred: OCRResult = run_ocr([pil_img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
    else:
        img_pred: OCRResult = run_ocr([pil_img], [langs])[0]
bboxes = [l.bbox for l in img_pred.text_lines]
text = [l.text for l in img_pred.text_lines]
rec_img = draw_text_on_image(bboxes, text, pil_img.size, langs, has_math="_math" in langs)
return rec_img, img_pred
# ===================== Input Handling =====================
if "pdf" in filetype:
try:
pg_cnt = page_count(in_file)
except Exception as e:
st.error(f"خواندن PDF ناموفق بود: {e}")
st.stop()
page_number = st.sidebar.number_input("صفحه:", min_value=1, value=1, max_value=pg_cnt)
pil_image = get_page_image(in_file, page_number)
else:
bytes_data = in_file.getvalue()
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
file_path = os.path.join(temp_dir, in_file.name)
with open(file_path, "wb") as f:
f.write(bytes_data)
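    # remove_border writes its deskewed copy next to the upload with a "-1.JPG" suffix.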
out_file = os.path.splitext(file_path)[0] + "-1.JPG"
try:
if auto_border:
_ = remove_border(file_path, out_file)
pil_image = Image.open(out_file).convert("RGB")
else:
pil_image = Image.open(file_path).convert("RGB")
except Exception as e:
st.warning(f"حذف قاب/بازخوانی تصویر با خطا مواجه شد؛ تصویر اصلی استفاده می‌شود. Error: {e}")
pil_image = Image.open(file_path).convert("RGB")
# Auto-rotate if enabled
pil_image = _apply_auto_rotate(pil_image)
# ===================== Buttons Logic =====================
with col1:
if text_det_btn and DET_AVAILABLE:
try:
det_img, det_pred = text_detection(pil_image)
st.image(det_img, caption="تشخیص متن (Detection)", use_column_width=True)
except Exception as e:
st.error(f"خطا در تشخیص متن: {e}")
if layout_det_btn and DET_AVAILABLE:
try:
layout_img, layout_pred = layout_detection(pil_image)
st.image(layout_img, caption="آنالیز صفحه (Layout)", use_column_width=True)
except Exception as e:
st.error(f"خطا در آنالیز صفحه: {e}")
if order_det_btn and DET_AVAILABLE:
try:
order_img, order_pred = order_detection(pil_image)
st.image(order_img, caption="ترتیب خوانش (Reading Order)", use_column_width=True)
except Exception as e:
st.error(f"خطا در ترتیب خوانش: {e}")
if text_rec_btn:
try:
rec_img, ocr_pred = ocr_page(pil_image, languages)
text_tab, json_tab = st.tabs(["متن صفحه | Page Text", "JSON"])
with text_tab:
st.text("\n".join([p.text for p in ocr_pred.text_lines]))
with json_tab:
st.json(ocr_pred.model_dump(), expanded=False)
except Exception as e:
st.error(f"خطا در بازشناسی متن (Recognition): {e}")
with col2:
st.image(pil_image, caption="تصویر ورودی | Input Preview", use_column_width=True)