import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import os
import base64
import spaces
import io
from PIL import Image
import numpy as np
import yaml
from pathlib import Path
from globe import title, description, modelinfor, joinus
import uuid
import tempfile
import time
import shutil

model_name = 'ucaslcl/GOT-OCR2_0'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
model = model.eval().cuda()
model.config.pad_token_id = tokenizer.eos_token_id
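
# The GOT-OCR2_0 remote code exposes the two inference entry points used below:
#   model.chat(tokenizer, image_path, ocr_type='ocr')               -> plain-text OCR
#   model.chat(tokenizer, image_path, ocr_type='format', ...)       -> formatted (Markdown/LaTeX-style) OCR
#   model.chat_crop(tokenizer, image_path, ocr_type='format', ...)  -> multi-crop OCR for large or dense pages
# Fine-grained calls additionally take ocr_box="[x1,y1,x2,y2]" or ocr_color="red"/"green"/"blue",
# and render=True with save_render_file=<path> writes an HTML rendering of the result.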

UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    if not os.path.exists(folder):
        os.makedirs(folder)

def image_to_base64(image):
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
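
# Note: image_to_base64 is not referenced elsewhere in this script; it is a small helper kept
# for embedding PIL images as base64-encoded PNG strings, e.g. (illustrative only):
#   data_uri = f"data:image/png;base64,{image_to_base64(Image.open('sample.png'))}"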

def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    if image is None:
        return "Error: No image provided", None, None

    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    try:
        # Normalize the Gradio input (ImageEditor dict, numpy array, or file path) to a PNG on disk.
        if isinstance(image, dict):  # Image comes from the ImageEditor component
            composite_image = image.get("composite")
            if composite_image is not None:
                if isinstance(composite_image, np.ndarray):
                    Image.fromarray(composite_image).save(image_path)
                elif isinstance(composite_image, Image.Image):
                    # The ImageEditor below is declared with type="pil", so the composite is a PIL image.
                    composite_image.save(image_path)
                elif isinstance(composite_image, str):
                    shutil.copy(composite_image, image_path)
                else:
                    return "Error: Unsupported image format from ImageEditor", None, None
            else:
                return "Error: No composite image found in ImageEditor output", None, None
        elif isinstance(image, np.ndarray):
            Image.fromarray(image).save(image_path)
        elif isinstance(image, str):
            shutil.copy(image, image_path)
        else:
            return "Error: Unsupported image format", None, None

        if task == "Plain Text OCR":
            res = model.chat(tokenizer, image_path, ocr_type='ocr')
            return res, None, unique_id
        else:
            if task == "Format Text OCR":
                res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
            elif task == "Fine-grained OCR (Box)":
                res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
            elif task == "Fine-grained OCR (Color)":
                res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
            elif task == "Multi-crop OCR":
                res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
            elif task == "Render Formatted OCR":
                res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)

            if os.path.exists(result_path):
                with open(result_path, 'r') as f:
                    html_content = f.read()
                return res, html_content, unique_id
            else:
                return res, None, unique_id
    except Exception as e:
        return f"Error: {str(e)}", None, None
    finally:
        # The uploaded copy is always removed; only the rendered HTML result is kept.
        if os.path.exists(image_path):
            os.remove(image_path)
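
# Minimal sketch of calling the pipeline directly (outside Gradio), assuming a local file
# "sample.png" exists; illustrative only:
#   text, html, uid = process_image("sample.png", "Plain Text OCR")
#   boxed, html, uid = process_image("sample.png", "Fine-grained OCR (Box)",
#                                    ocr_type="ocr", ocr_box="[100,100,200,200]")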

def update_image_input(task):
    if task == "Fine-grained OCR (Color)":
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
    else:
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

def update_inputs(task):
    if task in ["Plain Text OCR", "Format Text OCR", "Multi-crop OCR", "Render Formatted OCR"]:
        return [gr.update(visible=False)] * 5 + [gr.update(visible=True), gr.update(visible=False)]
    elif task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False)
        ]
    elif task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True)
        ]
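
# The seven gr.update(...) values returned by update_inputs map positionally onto the outputs
# wired up in task_dropdown.change below:
#   [ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, image_input, image_editor,
#    submit_button, editor_submit_button]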

def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    res, html_content, unique_id = process_image(image, task, ocr_type, ocr_box, ocr_color)
    if res.startswith("Error:"):
        return res, None

    # Insert a space after "\title" so the token is not glued to the following text, then wrap
    # the whole result in $$ so formatted output renders as LaTeX in the Markdown component.
    res = res.replace("\\title", "\\title ")
    res = f"$$ {res} $$"

    if html_content:
        # Expose the rendered HTML both as an inline iframe preview and as a downloadable file,
        # each embedded as a base64 data: URI named after the request's unique id.
        encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
        return res, f"{download_link}<br>{iframe}"
    return res, None

def cleanup_old_files():
    current_time = time.time()
    for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
        for file_path in Path(folder).glob('*'):
            if current_time - file_path.stat().st_mtime > 3600:  # 1 hour
                file_path.unlink()
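
# Note: cleanup_old_files is only invoked once, at launch time (see the __main__ block below),
# so rendered HTML results older than one hour are purged on restart rather than continuously.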

with gr.Blocks(theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(description)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(modelinfor)
                gr.Markdown(joinus)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                image_input = gr.Image(type="filepath", label="Input Image")
                image_editor = gr.ImageEditor(label="Image Editor", type="pil", visible=False)
                task_dropdown = gr.Dropdown(
                    choices=[
                        "Plain Text OCR",
                        "Format Text OCR",
                        "Fine-grained OCR (Box)",
                        "Fine-grained OCR (Color)",
                        "Multi-crop OCR",
                        "Render Formatted OCR"
                    ],
                    label="Select Task",
                    value="Plain Text OCR"
                )
                ocr_type_dropdown = gr.Dropdown(
                    choices=["ocr", "format"],
                    label="OCR Type",
                    visible=False
                )
                ocr_box_input = gr.Textbox(
                    label="OCR Box (x1,y1,x2,y2)",
                    placeholder="[100,100,200,200]",
                    visible=False
                )
                ocr_color_dropdown = gr.Dropdown(
                    choices=["red", "green", "blue"],
                    label="OCR Color",
                    visible=False
                )
                submit_button = gr.Button("Process")
                editor_submit_button = gr.Button("Process Edited Image", visible=False)
        with gr.Column(scale=1):
            with gr.Group():
                output_markdown = gr.Markdown(label="🫴🏻📸GOT-OCR")
                output_html = gr.HTML(label="🫴🏻📸GOT-OCR")

    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, image_input, image_editor, submit_button, editor_submit_button]
    )
    task_dropdown.change(
        update_image_input,
        inputs=[task_dropdown],
        outputs=[image_input, image_editor, editor_submit_button]
    )
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )
    editor_submit_button.click(
        ocr_demo,
        inputs=[image_editor, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )
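
# Both task_dropdown.change handlers fire on every task change: update_inputs toggles all seven
# controls, then update_image_input re-applies visibility for the image input, the editor, and
# the editor submit button (the ImageEditor is only shown for "Fine-grained OCR (Color)").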

if __name__ == "__main__":
    cleanup_old_files()
    demo.launch()