Spaces: Running on Zero

Commit 58b56ea · Parent(s): 864e5c4

Refactor OCR model loading to use lazy initialization and enhance error handling in predict function
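In short: the previous version built the processor, model, and pipeline at import time; this commit initializes them as module-level None globals and loads them on the first call to predict, which runs under @spaces.GPU so a ZeroGPU device is actually attached while the weights load. A load failure is recorded in MODEL_LOAD_ERROR_MSG and re-raised as a RuntimeError, which run_hf_ocr turns into a user-visible error string. A minimal sketch of the pattern, condensed from the diff below (status prints and some error wording trimmed):

import spaces
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline

HF_PIPE = None                # built lazily on the first OCR request
MODEL_LOAD_ERROR_MSG = None   # remembers a failed load so later calls fail fast

@spaces.GPU  # ZeroGPU attaches a GPU only while this function runs
def predict(pil_image):
    global HF_PIPE, MODEL_LOAD_ERROR_MSG
    if HF_PIPE is None and MODEL_LOAD_ERROR_MSG is None:
        try:
            # Heavy work happens here, on the first call, inside the GPU context.
            processor = AutoProcessor.from_pretrained("reducto/RolmOCR")
            model = AutoModelForImageTextToText.from_pretrained(
                "reducto/RolmOCR", torch_dtype=torch.bfloat16, device_map="auto"
            )
            HF_PIPE = pipeline("image-text-to-text", model=model, processor=processor)
        except Exception as e:
            MODEL_LOAD_ERROR_MSG = f"Error loading Hugging Face model: {e}"
    if HF_PIPE is None:
        # Either loading just failed or it already failed on an earlier call.
        raise RuntimeError(MODEL_LOAD_ERROR_MSG or "OCR model could not be initialized.")
    return HF_PIPE(pil_image, prompt="Return the plain text representation of this document as if you were reading it naturally.\n")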
app.py CHANGED

@@ -5,21 +5,12 @@ import os
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 import spaces
-
-# …
-…
-        torch_dtype=torch.bfloat16,
-        # attn_implementation="flash_attention_2", # User had this commented out
-        device_map="auto"
-    )
-    HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
-    print("Hugging Face OCR model loaded successfully.")
-except Exception as e:
-    print(f"Error loading Hugging Face model: {e}")
-    HF_PIPE = None
+
+# --- Global Model and Processor (initialize as None for lazy loading) ---
+HF_PROCESSOR = None
+HF_MODEL = None
+HF_PIPE = None
+MODEL_LOAD_ERROR_MSG = None  # To store any error message from loading
 
 # --- Helper Functions ---

@@ -68,72 +59,87 @@ def parse_alto_xml_for_text(xml_file_path):
     except Exception as e:
         return f"An unexpected error occurred during XML parsing: {e}"
 
+@spaces.GPU  # Ensures GPU is available for model loading (on first call) and inference
+def predict(pil_image):
+    """Performs OCR prediction using the Hugging Face model, with lazy loading."""
+    global HF_PROCESSOR, HF_MODEL, HF_PIPE, MODEL_LOAD_ERROR_MSG
+
+    if HF_PIPE is None and MODEL_LOAD_ERROR_MSG is None:
+        try:
+            print("Attempting to load Hugging Face model and processor within @spaces.GPU context...")
+            HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
+            HF_MODEL = AutoModelForImageTextToText.from_pretrained(
+                "reducto/RolmOCR",
+                torch_dtype=torch.bfloat16,
+                device_map="auto"  # Should utilize ZeroGPU correctly here
+            )
+            HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
+            print("Hugging Face OCR model loaded successfully.")
+        except Exception as e:
+            MODEL_LOAD_ERROR_MSG = f"Error loading Hugging Face model: {str(e)}"
+            print(MODEL_LOAD_ERROR_MSG)
+            # HF_PIPE remains None, error message is stored
+
+    if HF_PIPE is None:
+        error_to_report = MODEL_LOAD_ERROR_MSG if MODEL_LOAD_ERROR_MSG else "OCR model could not be initialized."
+        raise RuntimeError(error_to_report)
+
+    # Proceed with inference if pipe is available
+    return HF_PIPE(
+        pil_image,
+        prompt="Return the plain text representation of this document as if you were reading it naturally.\n",
+    )
+
 def run_hf_ocr(image_path):
     """
-    Runs OCR on the provided image using the …
+    Runs OCR on the provided image using the Hugging Face model (via predict function).
     """
-    if HF_PIPE is None:
-        return "Hugging Face OCR model not available."
     if image_path is None:
         return "No image provided for OCR."
 
     try:
-        # Load the image using PIL, as the pipeline expects an image object or path
         pil_image = Image.open(image_path).convert("RGB")
-
-        # The user's example output for the pipeline call was:
-        # [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]
-        # This suggests the pipeline is returning a conversational style output.
-        # We will try to call the pipeline with the image and prompt directly.
-        ocr_results = predict(pil_image)
+        ocr_results = predict(pil_image)  # predict handles model loading and inference
 
         # Parse the output based on the user's example structure
         if isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]:
             generated_content = ocr_results[0]['generated_text']
 
-            # Check if generated_content itself is the direct text (some pipelines do this)
             if isinstance(generated_content, str):
                 return generated_content
 
-            # Check for the conversational structure
-            # [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]
             if isinstance(generated_content, list) and generated_content:
-                …
+                if assistant_message := next(
+                    (
+                        msg['content']
+                        for msg in reversed(generated_content)
+                        if isinstance(msg, dict)
+                        and msg.get('role') == 'assistant'
+                        and 'content' in msg
+                    ),
+                    None,
+                ):
                     return assistant_message
-                …
+
+                # Fallback if the specific assistant message structure isn't found but there's content
+                if isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
+                    if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
+                        return generated_content[1]['content']  # Assuming second part is assistant
+                    elif 'content' in generated_content[0]:  # Or if first part is already the content
+                        return generated_content[0]['content']
 
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: Could not parse OCR model output."
+            return "Error: Could not parse OCR model output. Check console."
 
         else:
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: OCR model did not return expected output."
+            return "Error: OCR model did not return expected output. Check console."
 
+    except RuntimeError as e:  # Catch model loading/initialization errors from predict
+        return str(e)
     except Exception as e:
-        print(f"Error during Hugging Face OCR: {e}")
+        print(f"Error during Hugging Face OCR processing: {e}")
         return f"Error during Hugging Face OCR: {str(e)}"
-@spaces.GPU
-def predict(pil_image):
-    ocr_results = HF_PIPE(
-        pil_image,
-        prompt="Return the plain text representation of this document as if you were reading it naturally.\n"
-        # The pipeline should handle formatting this into messages if needed by the model.
-    )
-
-    return ocr_results
 
 # --- Gradio Interface Function ---

@@ -241,5 +247,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 if __name__ == "__main__":
     # Removed dummy file creation as it's less relevant for single file focus
     print("Attempting to launch Gradio demo...")
-    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading.")
+    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt).")
     demo.launch()
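
For reference on the parsing logic: the comments removed above described the pipeline's conversational output as roughly [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]. The sketch below applies the same extraction idea as the new run_hf_ocr code to made-up sample data; the real pipeline output may differ in detail.

# Made-up sample mimicking the conversational output shape described in the app's comments.
sample_output = [{
    "generated_text": [
        {"role": "user", "content": "Return the plain text representation of this document..."},
        {"role": "assistant", "content": "Recognized document text goes here."},
    ]
}]

generated_content = sample_output[0]["generated_text"]

# Same idea as the new code in run_hf_ocr: walk the messages from the end and
# return the content of the most recent assistant turn, or None if there is none.
assistant_message = next(
    (
        msg["content"]
        for msg in reversed(generated_content)
        if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg
    ),
    None,
)

print(assistant_message)  # -> Recognized document text goes here.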
|