Add OCR app and dependencies
Browse files- ocr_compare_app.py +244 -0
- requirements.txt +9 -0
ocr_compare_app.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import google.generativeai as genai
|
3 |
+
from mistralai import Mistral
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
import base64
|
7 |
+
import logging
|
8 |
+
from ocr_secrets import GEMINI_API_KEY, MISTRAL_API_KEY
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import threading
|
12 |
+
import time
|
13 |
+
|
14 |
+
# --- Logging ---------------------------------------------------------------
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Gemini client ---------------------------------------------------------
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel('gemini-2.0-flash-exp')

# --- Vote storage ----------------------------------------------------------
# VOTES_FILE is the JSON file the votes are persisted to, ``votes`` is the
# in-memory list of vote records, and ``votes_lock`` guards access to that
# list from the background saver thread and the vote handlers.
VOTES_FILE = "votes.json"
votes = []
votes_lock = threading.Lock()
|
25 |
+
|
26 |
+
def load_votes():
    """Populate the module-level ``votes`` list from ``VOTES_FILE``.

    When the file does not exist, ``votes`` is reset to an empty list.
    Otherwise ``votes`` becomes whatever JSON value the file contains.
    """
    global votes
    if not os.path.exists(VOTES_FILE):
        votes = []
        return
    with open(VOTES_FILE, "r") as votes_fh:
        votes = json.load(votes_fh)
|
33 |
+
|
34 |
+
def save_votes():
    """Persist the in-memory ``votes`` list to ``VOTES_FILE`` as indented JSON.

    The data is written to a sibling temporary file first and then moved over
    ``VOTES_FILE`` with ``os.replace``. Writing the target file in place would
    leave a truncated, unparseable file behind if the process stopped in the
    middle of the write, and ``load_votes`` would then fail on the next start.

    The whole operation runs under ``votes_lock`` so concurrent callers never
    share the temporary file and the list is not mutated while being dumped.
    """
    tmp_path = f"{VOTES_FILE}.tmp"
    with votes_lock:
        with open(tmp_path, "w") as f:
            json.dump(votes, f, indent=2)
        # Atomic rename: readers see either the old or the new complete file.
        os.replace(tmp_path, VOTES_FILE)
|
38 |
+
|
39 |
+
def periodic_vote_dump(interval=60):
    """Start a daemon thread that calls ``save_votes`` in an endless loop.

    Args:
        interval: Number of seconds to sleep between two saves.

    The first save happens immediately when the thread starts. A save that
    raises is logged and retried on the next cycle; without this, a single
    failure (for example a transient I/O error) would end the thread and
    silently stop all further persistence.
    """
    def run():
        while True:
            try:
                save_votes()
            except Exception:
                # Top-level boundary of the background thread: record the
                # traceback and keep the loop alive.
                logger.exception("Periodic vote save failed; retrying next cycle")
            time.sleep(interval)

    t = threading.Thread(target=run, daemon=True)
    t.start()
|
46 |
+
|
47 |
+
# Read any stored votes into memory, then start the background saver
# (one save every 60 seconds).
load_votes()
periodic_vote_dump(interval=60)
|
49 |
+
|
50 |
+
def get_default_username(profile: gr.OAuthProfile | None) -> str:
    """Return ``profile.username``, or an empty string when ``profile`` is None.

    A ``None`` profile means no user is logged in.
    """
    return "" if profile is None else profile.username
|
57 |
+
|
58 |
+
def gemini_ocr(image: Image.Image):
    """Transcribe the text in ``image`` with the Gemini model.

    The image is JPEG-encoded in memory and sent together with a fixed
    transcription prompt to ``gemini_model.generate_content``.

    Args:
        image: The PIL image to transcribe.

    Returns:
        The response text as a string, or a string starting with
        ``"Gemini OCR error: "`` if anything in the process raised.
    """
    try:
        # JPEG cannot store alpha-channel or palette images; normalise to RGB
        # first so such inputs are transcribed instead of failing in save().
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Convert image to bytes
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        image_bytes = buffered.getvalue()

        # Create image part
        image_part = {
            "mime_type": "image/jpeg",
            "data": image_bytes
        }

        # Generate content
        response = gemini_model.generate_content([
            "Please transcribe all text visible in this image in markdown format. Return only the transcribed text without any additional commentary. Do not include any icon names such as Pokeball icon or Clipboard icon. Only extract the actual text that a human can read directly from the image. Format the output clearly using appropriate Markdown, such as headings, bold text, and paragraphs. The output should contain only the transcribed text, with no additional explanation or description.",
            image_part
        ])

        # Read the text before logging success: accessing it can itself raise,
        # and that case must be reported as an error, not as a success.
        transcription = str(response.text)
        logger.info("Gemini OCR completed successfully")
        return transcription

    except Exception as e:
        logger.error(f"Gemini OCR error: {e}")
        return f"Gemini OCR error: {e}"
|
83 |
+
|
84 |
+
def mistral_ocr(image: Image.Image):
    """Transcribe the text in ``image`` with the Mistral OCR endpoint.

    The image is JPEG-encoded, base64-wrapped into a ``data:`` URL and sent
    to ``client.ocr.process`` with the ``mistral-ocr-latest`` model.

    Args:
        image: The PIL image to transcribe.

    Returns:
        The ``markdown`` attribute of the first response page when present
        and non-empty, otherwise ``str(ocr_response)``. If anything raises,
        a string starting with ``"Mistral OCR error: "`` is returned instead.
    """
    try:
        # JPEG cannot store alpha-channel or palette images; normalise to RGB
        # first so such inputs are transcribed instead of failing in save().
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Convert image to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_bytes = buffered.getvalue()
        base64_image = base64.b64encode(img_bytes).decode('utf-8')

        client = Mistral(api_key=MISTRAL_API_KEY)
        ocr_response = client.ocr.process(
            model="mistral-ocr-latest",
            document={
                "type": "image_url",
                "image_url": f"data:image/jpeg;base64,{base64_image}"
            }
        )

        # Extract markdown from the first page if available
        markdown_text = ""
        if hasattr(ocr_response, 'pages') and ocr_response.pages:
            page = ocr_response.pages[0]
            markdown_text = getattr(page, 'markdown', "")

        # Fall back to the raw response representation when no markdown
        # could be extracted.
        if not markdown_text:
            markdown_text = str(ocr_response)

        logger.info("Mistral OCR completed successfully")
        return markdown_text

    except Exception as e:
        logger.error(f"Mistral OCR error: {e}")
        return f"Mistral OCR error: {e}"
|
116 |
+
|
117 |
+
def process_image(image):
    """Run both OCR engines on ``image``.

    Returns:
        A ``(gemini_text, mistral_text)`` pair. When ``image`` is None both
        entries are an upload prompt; when an exception escapes either
        engine call both entries carry the same error message.
    """
    if image is None:
        prompt = "Please upload an image."
        return prompt, prompt

    try:
        # Tuple elements are evaluated left to right: Gemini first, then Mistral.
        return gemini_ocr(image), mistral_ocr(image)
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        failure = f"Error processing image: {e}"
        return failure, failure
|
128 |
+
|
129 |
+
# Helper to get a unique image id (hash of image bytes)
|
130 |
+
def get_image_id(image):
    """Return a stable identifier for ``image``.

    The identifier is the SHA-256 hex digest of the image's in-memory JPEG
    encoding. The built-in ``hash()`` must not be used here: for ``bytes`` it
    is salted per interpreter process, so ids written to the votes file would
    not match ids computed after a restart and duplicate-vote detection
    would stop working.

    Args:
        image: An object with a ``mode`` attribute and PIL-style ``convert``
            and ``save(fp, format=...)`` methods.

    Returns:
        A 64-character lowercase hexadecimal string.
    """
    import hashlib  # function-scope import keeps this block self-contained

    # JPEG cannot store alpha-channel or palette images; normalise to RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return hashlib.sha256(buffered.getvalue()).hexdigest()
|
134 |
+
|
135 |
+
# Create the Gradio interface
|
136 |
+
with gr.Blocks(title="OCR Comparison: Gemini vs Mistral", css="""
    .output-box {
        border: 2px solid #e0e0e0;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
        background-color: #f9f9f9;
        min-height: 200px;
    }
    .output-box:hover {
        border-color: #007bff;
        box-shadow: 0 2px 8px rgba(0,123,255,0.1);
    }
""") as demo:
    gr.Markdown("# 🔍 OCR Comparison: Gemini vs Mistral")
    gr.Markdown("Upload an image to compare OCR results from Gemini and Mistral")

    # Authentication section
    with gr.Row():
        with gr.Column(scale=3):
            # Filled by the demo.load() call below with the logged-in
            # username; stays empty when nobody is logged in.
            username_display = gr.Textbox(
                label="Current User",
                placeholder="Please login with your Hugging Face account to vote",
                interactive=False,
                show_label=False
            )
        with gr.Column(scale=1):
            login_button = gr.LoginButton()

    # Comparison section: Gemini output, the image input, Mistral output.
    with gr.Row():
        with gr.Column():
            gemini_output = gr.Markdown(label="🤖 Gemini OCR Output", elem_classes=["output-box"])
            gemini_vote_btn = gr.Button("A is better", variant="primary", size="sm")

        image_input = gr.Image(type="pil", label="Upload or Paste Image")

        with gr.Column():
            mistral_output = gr.Markdown(label="🦅 Mistral OCR Output", elem_classes=["output-box"])
            mistral_vote_btn = gr.Button("B is better", variant="primary", size="sm")

    with gr.Row():
        process_btn = gr.Button("🔍 Run OCR", variant="primary")

    # Vote functions
    def _record_vote(image, username, winner, info_message):
        """Store one vote for ``winner`` and show ``info_message``.

        Nothing is stored when ``username`` is empty, when no image is
        loaded, or when this user already has a vote for this image; a
        notice explaining the reason is shown instead.
        """
        if not username:
            gr.Info("Please login with your Hugging Face account to vote.")
            return
        if image is None:
            # Without this guard get_image_id(None) raises AttributeError.
            gr.Info("Please upload an image before voting.")
            return

        image_id = get_image_id(image)
        with votes_lock:
            already_voted = any(
                v["image_id"] == image_id and v["username"] == username
                for v in votes
            )
            if already_voted:
                gr.Info("You already voted for this image.")
                return
            votes.append({"image_id": image_id, "username": username, "winner": winner})
        gr.Info(info_message)

    def vote_gemini(image, username):
        """Record a vote saying the Gemini transcription of ``image`` is better."""
        info_message = (
            f"<p>You voted for <strong style='color:green;'>👈 Gemini OCR</strong>.</p>"
            f"<p><span style='color:green;'>👈 Gemini OCR</span> - "
            f"<span style='color:blue;'>👉 Mistral OCR</span></p>"
        )
        _record_vote(image, username, "gemini", info_message)

    def vote_mistral(image, username):
        """Record a vote saying the Mistral transcription of ``image`` is better."""
        info_message = (
            f"<p>You voted for <strong style='color:blue;'>👉 Mistral OCR</strong>.</p>"
            f"<p><span style='color:green;'>👈 Gemini OCR</span> - "
            f"<span style='color:blue;'>👉 Mistral OCR</span></p>"
        )
        _record_vote(image, username, "mistral", info_message)

    # Event handlers
    process_btn.click(
        process_image,
        inputs=[image_input],
        outputs=[gemini_output, mistral_output],
    )

    # The username comes from ``username_display``, not from ``login_button``:
    # a button component's value is its label text, which is never empty, so
    # passing the button made the "not logged in" check in the vote handlers
    # unreachable and stored the label as the voter name.
    # NOTE(review): the textbox value is sent by the browser; if votes must be
    # tamper-proof, resolve the user server-side from the OAuth profile - confirm.
    gemini_vote_btn.click(
        vote_gemini,
        inputs=[image_input, username_display]
    )

    mistral_vote_btn.click(
        vote_mistral,
        inputs=[image_input, username_display]
    )

    # Update username display when user logs in
    demo.load(fn=get_default_username, inputs=None, outputs=username_display)
|
235 |
+
|
236 |
+
if __name__ == "__main__":
    logger.info("Starting OCR Comparison App...")
    try:
        try:
            # Try to launch on localhost first
            demo.launch(server_name="127.0.0.1", server_port=7860, share=False)
        except ValueError as e:
            # Fall back to a public share link when the local launch is rejected.
            logger.warning(f"Localhost not accessible: {e}")
            logger.info("Launching with public URL...")
            demo.launch(share=True)
    finally:
        # The periodic saver runs on a daemon thread and simply stops with the
        # process, so write out any votes cast since its last cycle.
        save_votes()
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
pillow
|
3 |
+
mistralai
|
4 |
+
uvicorn
|
5 |
+
numpy<2
|
6 |
+
google-generativeai
|
7 |
+
supabase
|
8 |
+
python-dotenv
|
9 |
+
websockets
|