import base64
import os
import re
import shutil
import time
import uuid
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from globe import description, title
from PIL import Image
from render import render_ocr_text

from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.image_utils import load_image

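# Load the GOT-OCR-2.0 processor and model onto the available device.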
model_name = "stepfun-ai/GOT-OCR-2.0-hf"

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(
    model_name, low_cpu_mem_usage=True, device_map=device
)
model = model.eval().to(device)

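# Scratch directories for uploaded images and rendered HTML results, plus the
# stop string used to terminate generation.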
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
stop_str = "<|im_end|>"
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    if not os.path.exists(folder):
        os.makedirs(folder)

input_index = 0


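# Run a single OCR request: save the incoming image(s) to disk, call the model
# for the selected task, and return (text, html_content_or_None, unique_id).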
@spaces.GPU()
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    if image is None:
        return "Error: No image provided", None, None

    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
    try:
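        # Normalize the Gradio input (single filepath/array/ImageEditor dict or
        # a gallery list) into a list of PIL images saved under image_path.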
        if not isinstance(image, (tuple, list)):
            image = [image]
        else:
            image = [img[0] for img in image]
        for i, img in enumerate(image):
            if isinstance(img, dict):
                composite_image = img.get("composite")
                if composite_image is not None:
                    if isinstance(composite_image, np.ndarray):
                        cv2.imwrite(
                            image_path, cv2.cvtColor(composite_image, cv2.COLOR_RGB2BGR)
                        )
                    elif isinstance(composite_image, Image.Image):
                        composite_image.save(image_path)
                    else:
                        return (
                            "Error: Unsupported image format from ImageEditor",
                            None,
                            None,
                        )
                else:
                    return (
                        "Error: No composite image found in ImageEditor output",
                        None,
                        None,
                    )
            elif isinstance(img, np.ndarray):
                cv2.imwrite(image_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            elif isinstance(img, str):
                shutil.copy(img, image_path)
            else:
                return "Error: Unsupported image format", None, None

            image[i] = load_image(image_path)

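        # Dispatch on the selected task; every branch runs greedy decoding with
        # the stop string and decodes only the tokens generated after the prompt.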
if task == "Plain Text OCR": |
|
inputs = processor(image, return_tensors="pt").to("cuda") |
|
generate_ids = model.generate( |
|
**inputs, |
|
do_sample=False, |
|
tokenizer=processor.tokenizer, |
|
stop_strings=stop_str, |
|
max_new_tokens=4096, |
|
) |
|
res = processor.decode( |
|
generate_ids[0, inputs["input_ids"].shape[1] :], |
|
skip_special_tokens=True, |
|
) |
|
return res, None, unique_id |
|
        else:
            if task == "Format Text OCR":
                inputs = processor(image, return_tensors="pt", format=True).to(device)
                generate_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    tokenizer=processor.tokenizer,
                    stop_strings=stop_str,
                    max_new_tokens=4096,
                )
                res = processor.decode(
                    generate_ids[0, inputs["input_ids"].shape[1] :],
                    skip_special_tokens=True,
                )
                ocr_type = "format"
            elif task == "Fine-grained OCR (Box)":
                inputs = processor(image, return_tensors="pt", box=ocr_box).to(device)
                generate_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    tokenizer=processor.tokenizer,
                    stop_strings=stop_str,
                    max_new_tokens=4096,
                )
                res = processor.decode(
                    generate_ids[0, inputs["input_ids"].shape[1] :],
                    skip_special_tokens=True,
                )
            elif task == "Fine-grained OCR (Color)":
                inputs = processor(image, return_tensors="pt", color=ocr_color).to(
                    device
                )
                generate_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    tokenizer=processor.tokenizer,
                    stop_strings=stop_str,
                    max_new_tokens=4096,
                )
                res = processor.decode(
                    generate_ids[0, inputs["input_ids"].shape[1] :],
                    skip_special_tokens=True,
                )
            elif task == "Multi-crop OCR":
                inputs = processor(
                    image,
                    return_tensors="pt",
                    format=True,
                    crop_to_patches=True,
                    max_patches=5,
                ).to(device)
                generate_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    tokenizer=processor.tokenizer,
                    stop_strings=stop_str,
                    max_new_tokens=4096,
                )
                res = processor.decode(
                    generate_ids[0, inputs["input_ids"].shape[1] :],
                    skip_special_tokens=True,
                )
                ocr_type = "format"
            elif task == "Multi-page OCR":
                inputs = processor(
                    image, return_tensors="pt", multi_page=True, format=True
                ).to(device)
                generate_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    tokenizer=processor.tokenizer,
                    stop_strings=stop_str,
                    max_new_tokens=4096,
                )
                res = processor.decode(
                    generate_ids[0, inputs["input_ids"].shape[1] :],
                    skip_special_tokens=True,
                )
                ocr_type = "format"

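            # Render the (possibly formatted) output to an HTML file so the UI
            # can preview it and offer it for download.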
            render_ocr_text(res, result_path, format_text=ocr_type == "format")
            if os.path.exists(result_path):
                with open(result_path, "r") as f:
                    html_content = f.read()
                return res, html_content, unique_id
            else:
                return res, None, unique_id
    except Exception as e:
        return f"Error: {str(e)}", None, None
    finally:
        if os.path.exists(image_path):
            os.remove(image_path)


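# Switch between the plain image input, the image editor (Fine-grained Color),
# and the gallery (Multi-page), together with their submit buttons.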
def update_image_input(task):
    if task == "Fine-grained OCR (Color)":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    elif task == "Multi-page OCR":
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )
    else:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )


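# Toggle visibility of the task-specific controls (OCR type, box, color) and of
# the matching input widget and submit button for each task.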
def update_inputs(task):
    if task in [
        "Plain Text OCR",
        "Format Text OCR",
        "Multi-crop OCR",
    ]:
        return [
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    elif task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    elif task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    elif task == "Multi-page OCR":
        return [
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        ]


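# Wrap lines of model output that look like LaTeX in $$ ... $$ delimiters so a
# Markdown renderer can display them as math. (Not referenced elsewhere in this
# file.)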
def parse_latex_output(res):
    lines = re.split(r"(\$\$.*?\$\$)", res, flags=re.DOTALL)
    parsed_lines = []
    in_latex = False
    latex_buffer = []

    for line in lines:
        if line == "\n":
            if in_latex:
                latex_buffer.append(line)
            else:
                parsed_lines.append(line)
            continue

        line = line.strip()

        latex_patterns = [r"\{", r"\}", r"\[", r"\]", r"\\", r"\$", r"_", r"^", r'"']
        contains_latex = any(re.search(pattern, line) for pattern in latex_patterns)

        if contains_latex:
            if not in_latex:
                in_latex = True
                latex_buffer = ["$$"]
            latex_buffer.append(line)
        else:
            if in_latex:
                latex_buffer.append("$$")
                parsed_lines.extend(latex_buffer)
                in_latex = False
                latex_buffer = []
            parsed_lines.append(line)

    if in_latex:
        latex_buffer.append("$$")
        parsed_lines.extend(latex_buffer)

    return "$$\\$$\n".join(parsed_lines)


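# Gradio callback: run OCR, then wrap any rendered HTML in a data-URI iframe
# plus a download link for the output panel.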
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    res, html_content, unique_id = process_image(
        image, task, ocr_type, ocr_box, ocr_color
    )

    if isinstance(res, str) and res.startswith("Error:"):
        return res, None

    res = res.replace("\\title", "\\title ")
    formatted_res = res

    if html_content:
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
        return formatted_res, f"{download_link}<br>{iframe}"
    return formatted_res, None


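# Delete uploads/results older than one hour; called once at startup.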
def cleanup_old_files():
    current_time = time.time()
    for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
        for file_path in Path(folder).glob("*"):
            if current_time - file_path.stat().st_mtime > 3600:
                file_path.unlink()


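# Build the Gradio UI: inputs and task controls on the left, text and HTML
# output on the right.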
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                image_input = gr.Image(type="filepath", label="Input Image")
                gallery_input = gr.Gallery(
                    type="filepath", label="Input images", visible=False
                )
                image_editor = gr.ImageEditor(
                    label="Image Editor", type="pil", visible=False
                )
                task_dropdown = gr.Dropdown(
                    choices=[
                        "Plain Text OCR",
                        "Format Text OCR",
                        "Fine-grained OCR (Box)",
                        "Fine-grained OCR (Color)",
                        "Multi-crop OCR",
                        "Multi-page OCR",
                    ],
                    label="Select Task",
                    value="Plain Text OCR",
                )
                ocr_type_dropdown = gr.Dropdown(
                    choices=["ocr", "format"], label="OCR Type", visible=False
                )
                ocr_box_input = gr.Textbox(
                    label="OCR Box (x1,y1,x2,y2)",
                    placeholder="[100,100,200,200]",
                    visible=False,
                )
                ocr_color_dropdown = gr.Dropdown(
                    choices=["red", "green", "blue"], label="OCR Color", visible=False
                )

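            # One submit button per input widget; the visibility handlers keep
            # only the button for the active input shown.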
            submit_button = gr.Button("Process", variant="primary")
            editor_submit_button = gr.Button(
                "Process Edited Image", visible=False, variant="primary"
            )
            gallery_submit_button = gr.Button(
                "Process Multiple Images", visible=False, variant="primary"
            )

        with gr.Column(scale=1):
            with gr.Group():
                output_markdown = gr.Textbox(label="Text output")
                output_html = gr.HTML(label="HTML output")

    input_types = [
        image_input,
        image_editor,
        gallery_input,
    ]

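    # Wire the task selector to both visibility handlers, and each submit
    # button to the shared ocr_demo callback.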
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
            image_input,
            image_editor,
            submit_button,
            editor_submit_button,
            gallery_input,
            gallery_submit_button,
        ],
    )

    task_dropdown.change(
        update_image_input,
        inputs=[task_dropdown],
        outputs=[
            image_input,
            image_editor,
            editor_submit_button,
            gallery_input,
            gallery_submit_button,
        ],
    )

    submit_button.click(
        ocr_demo,
        inputs=[
            image_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    editor_submit_button.click(
        ocr_demo,
        inputs=[
            image_editor,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    gallery_submit_button.click(
        ocr_demo,
        inputs=[
            gallery_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )

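    # Bundled example images for the formatted-OCR and fine-grained tasks.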
    example = gr.Examples(
        examples=[
            [
                "./sheet_music.png",
                "Format Text OCR",
                "format",
                None,
                None,
            ],
            [
                "./latex.png",
                "Format Text OCR",
                "format",
                None,
                None,
            ],
        ],
        inputs=[
            image_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    example_finegrained = gr.Examples(
        examples=[
            [
                "./multi_box.png",
                "Fine-grained OCR (Color)",
                "ocr",
                None,
                "red",
            ]
        ],
        inputs=[
            image_editor,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
        label="Fine-grained example",
    )

    gr.Markdown(
        "Space based on [Tonic's GOT-OCR](https://huggingface.co/spaces/Tonic/GOT-OCR)"
    )


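# Purge stale files from previous sessions, then start the app.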
if __name__ == "__main__":
    cleanup_old_files()
    demo.launch()