Spaces:
Running
Running
Jordan Legg
refactor: retrieve title and description from markdown; improve UI for more responsive usage
bbed54b
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig | |
| import os | |
| import base64 | |
| import spaces | |
| import io | |
| from PIL import Image | |
| import numpy as np | |
| import yaml | |
| import markdown | |
| from pathlib import Path | |
# Function to extract title and description from the markdown file
def extract_title_description(md_file_path):
    """Read a markdown file and return ``(title, description_html)``.

    The title comes from a YAML frontmatter block delimited by ``---``
    lines at the top of the file; the remaining markdown content is
    rendered to HTML and used as the description.

    Args:
        md_file_path: Path to the markdown file.

    Returns:
        Tuple ``(title, description)``. The title falls back to
        ``'Title Not Found'`` when the frontmatter is absent, empty,
        or lacks a ``title`` key.
    """
    with open(md_file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    title = 'Title Not Found'
    content_start = 0

    # Extract frontmatter (YAML) for the title, if present.
    if lines and lines[0].strip() == '---':
        frontmatter = []
        for idx, line in enumerate(lines[1:], 1):
            if line.strip() == '---':
                content_start = idx + 1
                break
            frontmatter.append(line)
        frontmatter_yaml = yaml.safe_load(''.join(frontmatter))
        # safe_load returns None for an empty frontmatter block.
        if isinstance(frontmatter_yaml, dict):
            title = frontmatter_yaml.get('title', 'Title Not Found')

    # Everything after the frontmatter is the description.
    description_md = ''.join(lines[content_start:])
    description = markdown.markdown(description_md)

    return title, description
# Location of the markdown file that supplies the page title and description.
_CONTENT_MD_PATH = 'content/index.md'

# Pull the title and description out of the markdown once, at startup.
title, description = extract_title_description(_CONTENT_MD_PATH)
# Load the GOT-OCR2 model and tokenizer onto the GPU.
# NOTE: the repo id is defined once and reused below (the original
# repeated the literal string, risking drift if the model changes).
model_name = 'ucaslcl/GOT-OCR2_0'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval().cuda()
# Ensure generation pads with EOS even if the checkpoint config omits it.
model.config.pad_token_id = tokenizer.eos_token_id
def image_to_base64(image):
    """Encode *image* (a PIL-style object with ``.save``) as a base64 PNG string."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    raw_bytes = png_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode()
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
    """Run the selected OCR task on *image* and return ``(text, html)``.

    Args:
        image: Path to the input image (the Gradio Image component uses
            ``type="filepath"``).
        task: One of the task names offered by the task dropdown.
        ocr_type: 'ocr' or 'format' — used by the fine-grained tasks.
        ocr_box: Bounding-box string "x1,y1,x2,y2" for box-guided OCR.
        ocr_color: Color name for color-guided OCR.
        render: Unused; rendering is selected via the task name.

    Returns:
        Tuple ``(result_text, html_content)`` where ``html_content`` is
        ``None`` unless the task renders formatted output.

    Raises:
        ValueError: If *task* is not a recognized task name.
    """
    if task == "Plain Text OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr')
    elif task == "Format Text OCR":
        res = model.chat(tokenizer, image, ocr_type='format')
    elif task == "Fine-grained OCR (Box)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box)
    elif task == "Fine-grained OCR (Color)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color)
    elif task == "Multi-crop OCR":
        res = model.chat_crop(tokenizer, image_file=image)
    elif task == "Render Formatted OCR":
        res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html')
        with open('./demo.html', 'r', encoding='utf-8') as f:
            html_content = f.read()
        return res, html_content
    else:
        # Previously an unrecognized task fell through to `return res, None`
        # with `res` unbound, raising UnboundLocalError; fail clearly instead.
        raise ValueError(f"Unknown task: {task!r}")
    return res, None
def update_inputs(task):
    """Return visibility updates for the four task-specific controls.

    The returned list maps positionally to
    ``[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, render_checkbox]``.

    Args:
        task: The task name selected in the dropdown.

    Returns:
        A list of four ``gr.update`` objects toggling control visibility.
    """
    if task in ("Plain Text OCR", "Format Text OCR", "Multi-crop OCR"):
        return [gr.update(visible=False)] * 4
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
        ]
    if task == "Render Formatted OCR":
        return [gr.update(visible=False)] * 3 + [gr.update(visible=True)]
    # Defensive fallback: the original implicitly returned None for an
    # unrecognized task, which Gradio cannot apply to four outputs.
    return [gr.update(visible=False)] * 4
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio callback: run OCR and return (result text, optional rendered HTML)."""
    text_result, rendered_html = process_image(image, task, ocr_type, ocr_box, ocr_color)
    if not rendered_html:
        return text_result, None
    return text_result, rendered_html
# Build the Gradio UI (gradio is already imported at the top of the file;
# the redundant re-import that was here has been removed).
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: title and description extracted from the markdown file.
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
        # Right column: app inputs and outputs.
        with gr.Column(scale=3):
            image_input = gr.Image(type="filepath", label="Input Image")
            task_dropdown = gr.Dropdown(
                choices=[
                    "Plain Text OCR",
                    "Format Text OCR",
                    "Fine-grained OCR (Box)",
                    "Fine-grained OCR (Color)",
                    "Multi-crop OCR",
                    "Render Formatted OCR",
                ],
                label="Select Task",
                value="Plain Text OCR",
            )
            # Task-specific controls, hidden until a task that needs them
            # is selected (see update_inputs).
            ocr_type_dropdown = gr.Dropdown(
                choices=["ocr", "format"],
                label="OCR Type",
                visible=False,
            )
            ocr_box_input = gr.Textbox(
                label="OCR Box (x1,y1,x2,y2)",
                placeholder="e.g., 100,100,200,200",
                visible=False,
            )
            ocr_color_dropdown = gr.Dropdown(
                choices=["red", "green", "blue"],
                label="OCR Color",
                visible=False,
            )
            render_checkbox = gr.Checkbox(
                label="Render Result",
                visible=False,
            )
            submit_button = gr.Button("Process")
            # OCR result below the Submit button.
            output_text = gr.Textbox(label="OCR Result")
            output_html = gr.HTML(label="Rendered HTML Output")

    # Show/hide the task-specific controls when the task selection changes.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, render_checkbox],
    )

    # Run OCR on button click.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_text, output_html],
    )

if __name__ == "__main__":
    demo.launch()