MohammadReza-Halakoo committed on
Commit
aac01c3
·
verified ·
1 Parent(s): 1981928

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -19
app.py CHANGED
@@ -701,23 +701,21 @@
701
 
702
  # app.py — TRUST OCR DEMO (Streamlit) — personal-recognition-only
703
 
 
 
704
  import os
705
  import io
706
  import tempfile
707
  from typing import List
708
 
709
- import numpy as np
710
- import cv2
711
- from PIL import Image
712
- import pypdfium2
713
- import pytesseract
714
-
715
- # -------------------- Safe, writable dirs & config (BEFORE importing streamlit) --------------------
716
- # Put everything under /tmp (world-writable on HF Spaces)
717
  os.environ.setdefault("HOME", "/tmp")
718
  os.environ.setdefault("STREAMLIT_CONFIG_DIR", "/tmp/.streamlit")
719
  os.environ.setdefault("STREAMLIT_RUNTIME_DIR", "/tmp/.streamlit")
720
  os.environ.setdefault("HF_HOME", "/tmp/hf_home")
 
 
721
 
722
  for d in (os.environ["STREAMLIT_CONFIG_DIR"], os.environ["STREAMLIT_RUNTIME_DIR"], os.environ["HF_HOME"]):
723
  os.makedirs(d, exist_ok=True)
@@ -736,16 +734,22 @@ if not os.path.exists(conf_path):
736
  "gatherUsageStats = false\n"
737
  )
738
 
739
- # HF auth (for private repos) — optional, picked up by transformers/surya automatically
740
  HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
741
  if HF_TOKEN:
742
  os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
743
  try:
744
  from huggingface_hub import login
745
- login(token=HF_TOKEN, add_to_git_credential=True)
746
  except Exception:
747
  pass
748
 
 
 
 
 
 
 
749
  import streamlit as st
750
 
751
  # -------------------- Surya imports (gated) --------------------
@@ -755,7 +759,7 @@ try:
755
  except Exception:
756
  DET_AVAILABLE = False
757
 
758
- from surya.layout import batch_layout_detection # will be gated via DET_AVAILABLE
759
 
760
  # Detection model loaders: try newer segformer, fall back to older
761
  try:
@@ -887,6 +891,7 @@ def load_layout_cached():
887
 
888
  @st.cache_resource(show_spinner=True)
889
  def load_order_cached():
 
890
  return load_order_model(checkpoint="vikp/surya_order"), load_order_processor(checkpoint="vikp/surya_order")
891
 
892
 
@@ -894,7 +899,7 @@ def load_order_cached():
894
  # Choose ONE source:
895
  # (A) Local folder path via env TRUSTOCR_PATH (e.g., /app/models/TrustOCR)
896
  # (B) Private HF repo via env TRUSTOCR_REPO (e.g., MohammadReza-Halakoo/TrustOCR) + HUGGINGFACE_HUB_TOKEN
897
- PERSONAL_MODEL_PATH = os.environ.get("TRUSTOCR_PATH") # local directory containing save_pretrained files
898
  PERSONAL_HF_REPO = os.environ.get("TRUSTOCR_REPO") # private repo id on HF Hub
899
 
900
  @st.cache_resource(show_spinner=True)
@@ -910,20 +915,20 @@ def load_rec_personal():
910
  return m, p
911
  else:
912
  raise RuntimeError(
913
- "No personal recognition model configured. "
914
- "Set TRUSTOCR_PATH to a local folder OR TRUSTOCR_REPO to a private HF repo."
915
  )
916
 
917
- # Load everything
918
  if DET_AVAILABLE:
919
  det_model, det_processor = load_det_cached()
920
  layout_model, layout_processor = load_layout_cached()
921
- order_model, order_processor = load_order_cached()
922
  else:
923
- det_model = det_processor = layout_model = layout_processor = order_model = order_processor = None
924
 
925
  # Always require personal recognition model (no fallback)
926
  rec_model, rec_processor = load_rec_personal()
 
 
927
 
928
 
929
  # ===================== High-level Ops =====================
@@ -960,6 +965,11 @@ def layout_detection(pil_img: Image.Image):
960
 
961
 
962
  def order_detection(pil_img: Image.Image):
 
 
 
 
 
963
  _, layout_pred = layout_detection(pil_img)
964
  bboxes = [l.bbox for l in layout_pred.bboxes]
965
  pred: OrderResult = batch_ordering([pil_img], [bboxes], order_model, order_processor)[0]
@@ -974,10 +984,10 @@ def ocr_page(pil_img: Image.Image, langs: List[str]):
974
  langs = list(langs) if langs else ["Persian"]
975
  replace_lang_with_code(langs) # in-place
976
 
977
- if det_model and det_processor and rec_model and rec_processor:
978
  img_pred: OCRResult = run_ocr([pil_img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
979
  else:
980
- # detection may be missing in some environments; use recognition-only flow
981
  img_pred: OCRResult = run_ocr([pil_img], [langs], rec_model=rec_model, rec_processor=rec_processor)[0]
982
 
983
  bboxes = [l.bbox for l in img_pred.text_lines]
@@ -1039,6 +1049,8 @@ with col1:
1039
  try:
1040
  order_img, order_pred = order_detection(pil_image)
1041
  st.image(order_img, caption="ترتیب خوانش (Reading Order)", use_column_width=True)
 
 
1042
  except Exception as e:
1043
  st.error(f"خطا در ترتیب خوانش: {e}")
1044
 
 
701
 
702
  # app.py — TRUST OCR DEMO (Streamlit) — personal-recognition-only
703
 
704
+ # app.py — TRUST OCR DEMO (Streamlit) with personal recognition model, safe dirs, eager attention, lazy order
705
+
706
  import os
707
  import io
708
  import tempfile
709
  from typing import List
710
 
711
+ # -------------------- Safe, writable dirs & envs BEFORE any ML imports --------------------
712
+ # Move every writable path under /tmp (world-writable on HF Spaces)
 
 
 
 
 
 
713
  os.environ.setdefault("HOME", "/tmp")
714
  os.environ.setdefault("STREAMLIT_CONFIG_DIR", "/tmp/.streamlit")
715
  os.environ.setdefault("STREAMLIT_RUNTIME_DIR", "/tmp/.streamlit")
716
  os.environ.setdefault("HF_HOME", "/tmp/hf_home")
717
+ # Avoid sdpa attention for surya_order
718
+ os.environ.setdefault("TRANSFORMERS_ATTENTION_IMPLEMENTATION", "eager")
719
 
720
  for d in (os.environ["STREAMLIT_CONFIG_DIR"], os.environ["STREAMLIT_RUNTIME_DIR"], os.environ["HF_HOME"]):
721
  os.makedirs(d, exist_ok=True)
 
734
  "gatherUsageStats = false\n"
735
  )
736
 
737
+ # HF auth (for private repos) — optional; transformers/surya read the token from the environment themselves
738
  HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
739
  if HF_TOKEN:
740
  os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
741
  try:
742
  from huggingface_hub import login
743
+ login(token=HF_TOKEN, add_to_git_credential=False)
744
  except Exception:
745
  pass
746
 
747
+ # -------------------- Light deps (safe to import now) --------------------
748
+ import numpy as np
749
+ import cv2
750
+ from PIL import Image
751
+ import pypdfium2
752
+ import pytesseract
753
  import streamlit as st
754
 
755
  # -------------------- Surya imports (gated) --------------------
 
759
  except Exception:
760
  DET_AVAILABLE = False
761
 
762
+ from surya.layout import batch_layout_detection # we'll gate usage via DET_AVAILABLE
763
 
764
  # Detection model loaders: try newer segformer, fall back to older
765
  try:
 
891
 
892
  @st.cache_resource(show_spinner=True)
893
  def load_order_cached():
894
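+ # The order model sometimes crashes with attention=sdpa; eager was requested earlier via the TRANSFORMERS_ATTENTION_IMPLEMENTATION env var. NOTE(review): confirm the installed transformers/surya versions actually honor this variable.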
895
  return load_order_model(checkpoint="vikp/surya_order"), load_order_processor(checkpoint="vikp/surya_order")
896
 
897
 
 
899
  # Choose ONE source:
900
  # (A) Local folder path via env TRUSTOCR_PATH (e.g., /app/models/TrustOCR)
901
  # (B) Private HF repo via env TRUSTOCR_REPO (e.g., MohammadReza-Halakoo/TrustOCR) + HUGGINGFACE_HUB_TOKEN
902
+ PERSONAL_MODEL_PATH = os.environ.get("TRUSTOCR_PATH") # local directory with save_pretrained files
903
  PERSONAL_HF_REPO = os.environ.get("TRUSTOCR_REPO") # private repo id on HF Hub
904
 
905
  @st.cache_resource(show_spinner=True)
 
915
  return m, p
916
  else:
917
  raise RuntimeError(
918
+ "مدل شخصی تنظیم نشده است. یکی را ست کن: TRUSTOCR_PATH (پوشه محلی) یا TRUSTOCR_REPO (ریپوی خصوصی HF)."
 
919
  )
920
 
921
+ # Load detection/layout immediately (they are lighter); the order model is lazy-loaded.
922
  if DET_AVAILABLE:
923
  det_model, det_processor = load_det_cached()
924
  layout_model, layout_processor = load_layout_cached()
 
925
  else:
926
+ det_model = det_processor = layout_model = layout_processor = None
927
 
928
  # Always require personal recognition model (no fallback)
929
  rec_model, rec_processor = load_rec_personal()
930
+ order_model = None
931
+ order_processor = None # will lazy-load on button click
932
 
933
 
934
  # ===================== High-level Ops =====================
 
965
 
966
 
967
  def order_detection(pil_img: Image.Image):
968
+ # Lazy-load order model on demand
969
+ global order_model, order_processor
970
+ if order_model is None or order_processor is None:
971
+ order_model, order_processor = load_order_cached()
972
+
973
  _, layout_pred = layout_detection(pil_img)
974
  bboxes = [l.bbox for l in layout_pred.bboxes]
975
  pred: OrderResult = batch_ordering([pil_img], [bboxes], order_model, order_processor)[0]
 
984
  langs = list(langs) if langs else ["Persian"]
985
  replace_lang_with_code(langs) # in-place
986
 
987
+ if DET_AVAILABLE and det_model and det_processor and rec_model and rec_processor:
988
  img_pred: OCRResult = run_ocr([pil_img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
989
  else:
990
+ # detection may be missing in some environments; use recognition-only path
991
  img_pred: OCRResult = run_ocr([pil_img], [langs], rec_model=rec_model, rec_processor=rec_processor)[0]
992
 
993
  bboxes = [l.bbox for l in img_pred.text_lines]
 
1049
  try:
1050
  order_img, order_pred = order_detection(pil_image)
1051
  st.image(order_img, caption="ترتیب خوانش (Reading Order)", use_column_width=True)
1052
+ except KeyError:
1053
+ st.error("Reading Order فعلاً با نسخه‌های فعلی سازگار نیست (attention=sdpa). نسخه‌ها را طبق requirements پین کن.")
1054
  except Exception as e:
1055
  st.error(f"خطا در ترتیب خوانش: {e}")
1056