Spaces:
Sleeping
Sleeping
File size: 13,346 Bytes
1eea0ea 2ef1dc3 7fc4d76 1eea0ea 94b3206 22e7155 1eea0ea 94b3206 1eea0ea 810036c 8936a61 1eea0ea 8936a61 1eea0ea 810036c 1eea0ea d5e4264 1eea0ea 7fc4d76 1eea0ea d5e4264 1eea0ea 0e2a088 d5e4264 1eea0ea 2ef1dc3 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea 2ef1dc3 d5e4264 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea d5e4264 7fc4d76 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea 94b3206 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 1eea0ea 22e7155 d5e4264 1eea0ea 2ef1dc3 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea d5e4264 1eea0ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
# app.py — TRUST OCR DEMO (Streamlit) — works even if batch_text_detection is missing
import os
import io
import tempfile
from typing import List
import numpy as np
import cv2
from PIL import Image
import pypdfium2
import pytesseract
# --- set safe dirs before importing streamlit ---
# All of the environment tweaks below are done before `import streamlit` so
# that they are already in place when Streamlit is first imported.

# Use HOME if it is set and non-empty, otherwise fall back to /app.
safe_home = os.getenv("HOME") or "/app"
os.environ["HOME"] = safe_home

# Config directory under HOME; exported so Streamlit writes its files here.
cfg_dir = os.path.join(safe_home, ".streamlit")
os.makedirs(cfg_dir, exist_ok=True)
os.environ["STREAMLIT_CONFIG_DIR"] = cfg_dir

# Seed a config.toml that disables usage-statistics gathering, but never
# overwrite one that already exists.
conf_path = os.path.join(cfg_dir, "config.toml")
if not os.path.exists(conf_path):
    with open(conf_path, mode="w", encoding="utf-8") as config_file:
        config_file.write("browser.gatherUsageStats = false\n")

# Runtime directory under the system temp dir.
runtime_dir = os.path.join(tempfile.gettempdir(), ".streamlit")
os.environ["STREAMLIT_RUNTIME_DIR"] = runtime_dir
os.makedirs(runtime_dir, exist_ok=True)
import streamlit as st
# ===== Safe runtime dir for Streamlit/HF cache =====
# runtime_dir = os.path.join(tempfile.gettempdir(), ".streamlit")
# os.environ["STREAMLIT_RUNTIME_DIR"] = runtime_dir
# os.makedirs(runtime_dir, exist_ok=True)
# ===== Try to import Surya APIs =====
# Everything detection-related (text detection, layout analysis and the
# detection model loaders) is optional. If any piece is missing from the
# installed Surya version, DET_AVAILABLE is set to False and the name is bound
# to None, instead of the import raising here. Callers gate on DET_AVAILABLE.
DET_AVAILABLE = True
try:
    from surya.detection import batch_text_detection
except Exception:
    DET_AVAILABLE = False
    batch_text_detection = None

try:
    from surya.layout import batch_layout_detection
except Exception:
    DET_AVAILABLE = False
    batch_layout_detection = None

# Detection model loaders: segformer (newer) vs model (older).
try:
    from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
except Exception:
    try:
        from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
    except Exception:
        DET_AVAILABLE = False
        load_det_model = load_det_processor = None
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.ocr import run_ocr
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.text import draw_text_on_image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
# ===================== Helper Functions =====================
def _order_corners(pts: np.ndarray) -> np.ndarray:
    """Order four (x, y) points as top-left, top-right, bottom-right, bottom-left."""
    rect = np.zeros((4, 2), dtype="float32")
    sums = pts.sum(axis=1)            # x + y
    rect[0] = pts[np.argmin(sums)]    # top-left: smallest x + y
    rect[2] = pts[np.argmax(sums)]    # bottom-right: largest x + y
    diffs = np.diff(pts, axis=1)      # y - x
    rect[1] = pts[np.argmin(diffs)]   # top-right: smallest y - x
    rect[3] = pts[np.argmax(diffs)]   # bottom-left: largest y - x
    return rect


def _warp_largest_quad(image: np.ndarray):
    """Perspective-crop *image* to its largest 4-cornered external contour.

    Returns the warped BGR image, or None when no usable quadrilateral exists:
    there are no contours, the largest one does not simplify to 4 corners, or
    the quadrilateral is too thin to warp.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    largest = max(contours, key=cv2.contourArea)
    # Polygon approximation with a tolerance of 2% of the contour perimeter.
    epsilon = 0.02 * cv2.arcLength(largest, True)
    approx = cv2.approxPolyDP(largest, epsilon, True)
    if len(approx) != 4:
        return None

    rect = _order_corners(approx.reshape(4, 2).astype("float32"))
    tl, tr, br, bl = rect
    width = max(int(np.linalg.norm(br - bl)), int(np.linalg.norm(tr - tl)))
    height = max(int(np.linalg.norm(tr - br)), int(np.linalg.norm(tl - bl)))
    if width < 2 or height < 2:
        # Degenerate (near-collinear) quad: the destination rectangle would
        # collapse and the perspective transform would be meaningless.
        return None

    dst = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype="float32",
    )
    matrix = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, matrix, (width, height))


def remove_border(image_path: str, output_path: str) -> np.ndarray:
    """Remove the outer border and undo perspective if a rectangular outline is found.

    The image is Otsu-binarised and its largest external contour is simplified
    to a polygon. If that polygon has exactly four corners, the region inside it
    is warped to an axis-aligned rectangle. Otherwise, or when the quad is
    degenerate, the image is kept unchanged. The result is always written to
    *output_path* and returned.

    Args:
        image_path: Path of the input image, read with ``cv2.imread``.
        output_path: Destination of the processed image. Its extension selects
            the OpenCV encoder.

    Returns:
        The processed image as a BGR ``np.ndarray``.

    Raises:
        ValueError: If *image_path* cannot be decoded or *output_path* cannot
            be written.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")

    warped = _warp_largest_quad(image)
    result = image if warped is None else warped

    # cv2.imwrite reports failure (unsupported extension, unwritable path) by
    # returning False instead of raising. Surface it, so callers never go on to
    # read a missing or stale file from output_path.
    if not cv2.imwrite(output_path, result):
        raise ValueError(f"Cannot write image: {output_path}")
    return result
def open_pdf(pdf_file) -> pypdfium2.PdfDocument:
    """Open an uploaded PDF as a pdfium document.

    *pdf_file* is anything exposing ``getvalue()`` (e.g. a Streamlit upload).
    Its full byte payload is copied into a fresh in-memory stream, which is
    what pdfium reads from.
    """
    payload = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(payload))
@st.cache_data(show_spinner=False)
def get_page_image(pdf_file, page_num: int, dpi: int = 96) -> Image.Image:
    """Render one page of an uploaded PDF to an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF (anything accepted by ``open_pdf``).
        page_num: 1-based page number.
        dpi: Render resolution. PDF user space is 72 units per inch, hence the
            ``dpi / 72`` scale factor.

    Returns:
        The rendered page as an ``RGB`` image. Streamlit caches it per argument set.
    """
    doc = open_pdf(pdf_file)
    try:
        page = doc[page_num - 1]
        try:
            bitmap = page.render(scale=dpi / 72)
            # convert() returns a new, independent image, so the result stays
            # valid after the native page/document handles are closed below.
            return bitmap.to_pil().convert("RGB")
        finally:
            page.close()
    finally:
        # Release the native pdfium document now instead of waiting for GC.
        doc.close()
@st.cache_data(show_spinner=False)
def page_count(pdf_file) -> int:
    """Return the number of pages in an uploaded PDF (cached by Streamlit).

    Args:
        pdf_file: Uploaded PDF (anything accepted by ``open_pdf``).
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the native pdfium document now instead of waiting for GC.
        doc.close()
# ===================== Streamlit UI =====================
# Page header. set_page_config is issued before any other Streamlit element.
st.set_page_config(page_title="TRUST OCR DEMO", layout="wide")
st.markdown("# TRUST OCR DEMO")

if not DET_AVAILABLE:
    # Explain why the detection-based buttons further down are disabled.
    st.warning("⚠️ ماژول تشخیص متن Surya در این محیط در دسترس نیست. OCR کامل کار میکند، اما دکمههای Detection/Layout/Order غیرفعال شدهاند. برای فعالسازی آنها، Surya را به نسخهٔ سازگار پین کنید (راهنما پایین صفحه).")

# ---- Sidebar: input file, OCR languages and preprocessing switches ----
in_file = st.sidebar.file_uploader("فایل PDF یا عکس :", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
languages = st.sidebar.multiselect(
    "زبانها (Languages)",
    sorted(CODE_TO_LANGUAGE.values()),
    default=["Persian"],
    max_selections=4,
)
auto_rotate = st.sidebar.toggle("چرخش خودکار (Tesseract OSD)", value=True)
auto_border = st.sidebar.toggle("حذف قاب/کادر تصویر ورودی", value=True)

# ---- Sidebar: actions. Detection / Layout / Order require Surya's detection API ----
detection_disabled = not DET_AVAILABLE
text_det_btn = st.sidebar.button("تشخیص متن (Detection)", disabled=detection_disabled)
layout_det_btn = st.sidebar.button("آنالیز صفحه (Layout)", disabled=detection_disabled)
order_det_btn = st.sidebar.button("ترتیب خوانش (Reading Order)", disabled=detection_disabled)
text_rec_btn = st.sidebar.button("تبدیل به متن (Recognition)")

# Nothing below can run until a file has been uploaded.
if in_file is None:
    st.info("یک فایل PDF/عکس از سایدبار انتخاب کنید. | Please upload a file to begin.")
    st.stop()

filetype = in_file.type

# Two equal-width columns. NOTE: the names are swapped relative to position:
# col2 is the LEFT (first) column and col1 is the RIGHT (second) column.
col2, col1 = st.columns([0.5, 0.5])
# ===================== Load Models (cached) =====================
def _model_and_processor(load_model_fn, load_processor_fn, checkpoint):
    """Load a model, then its processor, from the same checkpoint and return both."""
    model = load_model_fn(checkpoint=checkpoint)
    processor = load_processor_fn(checkpoint=checkpoint)
    return model, processor


@st.cache_resource(show_spinner=True)
def load_det_cached():
    """Surya text-detection model + processor (cached resource)."""
    return _model_and_processor(load_det_model, load_det_processor, "vikp/surya_det2")


@st.cache_resource(show_spinner=True)
def load_rec_cached():
    """TrustOCR recognition model + processor (cached resource)."""
    return _model_and_processor(load_rec_model, load_rec_processor, "MohammadReza-Halakoo/TrustOCR")


@st.cache_resource(show_spinner=True)
def load_layout_cached():
    """Layout model + processor, loaded through the detection loaders (cached resource)."""
    return _model_and_processor(load_det_model, load_det_processor, "vikp/surya_layout2")


@st.cache_resource(show_spinner=True)
def load_order_cached():
    """Reading-order model + processor (cached resource)."""
    return _model_and_processor(load_order_model, load_order_processor, "vikp/surya_order")
# The recognition model is always loaded. The detection, layout and ordering
# models are loaded only when DET_AVAILABLE; otherwise they stay None.
rec_model, rec_processor = load_rec_cached()
det_model = det_processor = None
layout_model = layout_processor = None
order_model = order_processor = None
if DET_AVAILABLE:
    det_model, det_processor = load_det_cached()
    layout_model, layout_processor = load_layout_cached()
    order_model, order_processor = load_order_cached()
# ===================== High-level Ops =====================
def _apply_auto_rotate(pil_img: Image.Image) -> Image.Image:
    """Rotate *pil_img* according to Tesseract's orientation detection (OSD).

    Does nothing when the sidebar ``auto_rotate`` toggle is off. Any failure
    (OSD or rotation) is shown as a Streamlit warning and the input image is
    returned unchanged.
    """
    if auto_rotate:
        try:
            osd_info = pytesseract.image_to_osd(pil_img, output_type=pytesseract.Output.DICT)
            correction = int(osd_info.get("rotate", 0))  # 0/90/180/270
            if correction % 360 != 0:
                # PIL's positive angles are counter-clockwise, so -correction
                # turns the image clockwise by `correction` degrees.
                return pil_img.rotate(-correction, expand=True)
        except Exception as e:
            st.warning(f"OSD rotation failed, continuing without rotation. Error: {e}")
    return pil_img
def _predict_text_boxes(pil_img: Image.Image) -> TextDetectionResult:
    """Run Surya text detection on one page and return its prediction only."""
    return batch_text_detection([pil_img], det_model, det_processor)[0]


def _predict_layout(pil_img: Image.Image) -> LayoutResult:
    """Run Surya layout analysis on one page; it consumes the text-detection result."""
    det_pred = _predict_text_boxes(pil_img)
    return batch_layout_detection([pil_img], layout_model, layout_processor, [det_pred])[0]


def text_detection(pil_img: Image.Image):
    """Detect text regions and draw their polygons on a copy of *pil_img*.

    Returns:
        ``(det_img, pred)``: the annotated copy and the ``TextDetectionResult``.
    """
    pred = _predict_text_boxes(pil_img)
    polygons = [box.polygon for box in pred.bboxes]
    det_img = draw_polys_on_image(polygons, pil_img.copy())
    return det_img, pred


def layout_detection(pil_img: Image.Image):
    """Analyse page layout and draw labelled region polygons on a copy of *pil_img*.

    Returns:
        ``(layout_img, pred)``: the annotated copy and the ``LayoutResult``.
    """
    # Uses the prediction-only helper, so no detection overlay is drawn just to be discarded.
    pred = _predict_layout(pil_img)
    polygons = [box.polygon for box in pred.bboxes]
    labels = [box.label for box in pred.bboxes]
    layout_img = draw_polys_on_image(polygons, pil_img.copy(), labels=labels, label_font_size=40)
    return layout_img, pred


def order_detection(pil_img: Image.Image):
    """Predict reading order of layout regions and draw their positions on a copy of *pil_img*.

    Returns:
        ``(order_img, pred)``: the annotated copy and the ``OrderResult``.
    """
    layout_pred = _predict_layout(pil_img)
    bboxes = [box.bbox for box in layout_pred.bboxes]
    pred: OrderResult = batch_ordering([pil_img], [bboxes], order_model, order_processor)[0]
    polys = [box.polygon for box in pred.bboxes]
    positions = [str(box.position) for box in pred.bboxes]
    order_img = draw_polys_on_image(polys, pil_img.copy(), labels=positions, label_font_size=40)
    return order_img, pred
def ocr_page(pil_img: Image.Image, langs: List[str]):
    """Run Surya full-page OCR on one image and render the recognized text.

    Args:
        pil_img: Page image to recognize.
        langs: Language names as offered in the sidebar (e.g. "Persian"). An
            empty or None selection falls back to ``["Persian"]``. The caller's
            sequence is copied and never modified.

    Returns:
        ``(rec_img, img_pred)``: the image produced by ``draw_text_on_image``
        (each recognized line drawn at its bbox) and the page's ``OCRResult``.
    """
    # Copy first: replace_lang_with_code rewrites names to codes in place.
    langs = list(langs) if langs else ["Persian"]
    replace_lang_with_code(langs)

    img_pred: OCRResult
    models = (det_model, det_processor, rec_model, rec_processor)
    if all(m is not None for m in models):
        img_pred = run_ocr([pil_img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
    else:
        # NOTE(review): this relies on run_ocr working without explicit detection
        # models. The call above passes them positionally, which suggests they
        # may be required in the targeted Surya version. Confirm this fallback
        # does not simply raise TypeError.
        img_pred = run_ocr([pil_img], [langs])[0]

    bboxes = [line.bbox for line in img_pred.text_lines]
    texts = [line.text for line in img_pred.text_lines]
    # After replace_lang_with_code, a "Math" selection presumably appears as "_math".
    rec_img = draw_text_on_image(bboxes, texts, pil_img.size, langs, has_math="_math" in langs)
    return rec_img, img_pred
# ===================== Input Handling =====================
# Produce `pil_image` (RGB) from the upload: a rendered page for PDFs, or the
# decoded image (optionally border-cropped) otherwise. Then auto-rotate it.

# The browser-supplied MIME type may be missing, so also check the extension.
is_pdf = "pdf" in (filetype or "") or in_file.name.lower().endswith(".pdf")

if is_pdf:
    try:
        pg_cnt = page_count(in_file)
    except Exception as e:
        st.error(f"خواندن PDF ناموفق بود: {e}")
        st.stop()
    page_number = st.sidebar.number_input("صفحه:", min_value=1, value=1, max_value=pg_cnt)
    pil_image = get_page_image(in_file, page_number)
else:
    # Use a private directory that is deleted afterwards. A shared, persistent
    # folder keyed by the upload's name grows without bound and lets sessions
    # that upload same-named files overwrite each other.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Never use the client-supplied name as a path: keep only its basename.
        safe_name = os.path.basename(in_file.name)
        if safe_name in ("", ".", ".."):
            safe_name = "upload"
        file_path = os.path.join(temp_dir, safe_name)
        with open(file_path, "wb") as f:
            f.write(in_file.getvalue())
        # PNG keeps the intermediate lossless (no JPEG artifacts before OCR).
        out_file = os.path.splitext(file_path)[0] + "-1.png"
        try:
            if auto_border:
                cropped_bgr = remove_border(file_path, out_file)
                # Use the returned pixels directly instead of re-reading from disk.
                pil_image = Image.fromarray(cv2.cvtColor(cropped_bgr, cv2.COLOR_BGR2RGB))
            else:
                pil_image = Image.open(file_path).convert("RGB")
        except Exception as e:
            st.warning(f"حذف قاب/بازخوانی تصویر با خطا مواجه شد؛ تصویر اصلی استفاده میشود. Error: {e}")
            try:
                pil_image = Image.open(file_path).convert("RGB")
            except Exception as open_err:
                # The upload itself is unreadable: report it instead of crashing.
                st.error(f"خواندن تصویر ناموفق بود: {open_err}")
                st.stop()

# Auto-rotate if enabled (the helper checks the sidebar toggle itself).
pil_image = _apply_auto_rotate(pil_image)
# ===================== Buttons Logic =====================
# Results of the pressed action are drawn in col1; the input preview in col2.
with col1:
    # The three detection-based actions are only reachable when the detection API imported.
    if DET_AVAILABLE:
        if text_det_btn:
            try:
                det_img, _ = text_detection(pil_image)
                st.image(det_img, caption="تشخیص متن (Detection)", use_column_width=True)
            except Exception as e:
                st.error(f"خطا در تشخیص متن: {e}")
        if layout_det_btn:
            try:
                layout_img, _ = layout_detection(pil_image)
                st.image(layout_img, caption="آنالیز صفحه (Layout)", use_column_width=True)
            except Exception as e:
                st.error(f"خطا در آنالیز صفحه: {e}")
        if order_det_btn:
            try:
                order_img, _ = order_detection(pil_image)
                st.image(order_img, caption="ترتیب خوانش (Reading Order)", use_column_width=True)
            except Exception as e:
                st.error(f"خطا در ترتیب خوانش: {e}")

    if text_rec_btn:
        try:
            rec_img, ocr_pred = ocr_page(pil_image, languages)
            text_tab, json_tab = st.tabs(["متن صفحه | Page Text", "JSON"])
            with text_tab:
                page_text = "\n".join(line.text for line in ocr_pred.text_lines)
                st.text(page_text)
            with json_tab:
                st.json(ocr_pred.model_dump(), expanded=False)
        except Exception as e:
            st.error(f"خطا در بازشناسی متن (Recognition): {e}")

with col2:
    st.image(pil_image, caption="تصویر ورودی | Input Preview", use_column_width=True)
|