"""
OCR Models Module

Contains all OCR-related functions for different AI models.
"""

import base64
import io
import logging
import os

import google.generativeai as genai
import openai
from mistralai import Mistral
from PIL import Image

logger = logging.getLogger(__name__)


def gemini_ocr(image: Image.Image):
    """Process OCR using Google's Gemini 2.0 Flash model."""
    try:
        gemini_model = initialize_gemini()
        if not gemini_model:
            return "Gemini OCR error: Failed to initialize Gemini model"

        # Encode the image as base64 JPEG; convert to RGB first because JPEG
        # cannot store an alpha channel.
        buffered = io.BytesIO()
        image.convert("RGB").save(buffered, format="JPEG")
        img_bytes = buffered.getvalue()
        base64_image = base64.b64encode(img_bytes).decode('utf-8')

        image_part = {
            "mime_type": "image/jpeg",
            "data": base64_image
        }

        response = gemini_model.generate_content([
            "Extract and transcribe all text from this image. Return only the transcribed text in markdown format, preserving any formatting like headers, lists, etc.",
            image_part
        ])

        markdown_text = response.text
        logger.info("Gemini OCR completed successfully")
        return markdown_text

    except Exception as e:
        logger.error(f"Gemini OCR error: {e}")
        return f"Gemini OCR error: {e}"


def mistral_ocr(image: Image.Image):
    """Process OCR using Mistral AI's OCR model."""
    try:
        # Encode the image as base64 JPEG; convert to RGB first because JPEG
        # cannot store an alpha channel.
        buffered = io.BytesIO()
        image.convert("RGB").save(buffered, format="JPEG")
        img_bytes = buffered.getvalue()
        base64_image = base64.b64encode(img_bytes).decode('utf-8')

        client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
        ocr_response = client.ocr.process(
            model="mistral-ocr-latest",
            document={
                "type": "image_url",
                "image_url": f"data:image/jpeg;base64,{base64_image}"
            }
        )

        # The OCR response is organized per page; a single image yields one page.
        markdown_text = ""
        if hasattr(ocr_response, 'pages') and ocr_response.pages:
            page = ocr_response.pages[0]
            markdown_text = getattr(page, 'markdown', "")

        if not markdown_text:
            markdown_text = str(ocr_response)

        logger.info("Mistral OCR completed successfully")
        return markdown_text

    except Exception as e:
        logger.error(f"Mistral OCR error: {e}")
        return f"Mistral OCR error: {e}"


def openai_ocr(image: Image.Image):
    """Process OCR using OpenAI's GPT-4o model."""
    try:
        # Encode the image as a base64 PNG data URL.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        base64_image = base64.b64encode(img_bytes).decode('utf-8')
        image_data_url = f"data:image/png;base64,{base64_image}"

        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Extract and transcribe all text from this image. Return only the transcribed text in markdown format, preserving any formatting like headers, lists, etc."},
                        {"type": "image_url", "image_url": {"url": image_data_url}}
                    ]
                }
            ]
        )

        markdown_text = response.choices[0].message.content
        logger.info("OpenAI OCR completed successfully")
        return markdown_text

    except Exception as e:
        logger.error(f"OpenAI OCR error: {e}")
        return f"OpenAI OCR error: {e}"


def gpt5_ocr(image: Image.Image):
    """Process OCR using OpenAI's GPT-5 model with the same prompt."""
    try:
        # Encode the image as a base64 PNG data URL.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        base64_image = base64.b64encode(img_bytes).decode('utf-8')
        image_data_url = f"data:image/png;base64,{base64_image}"

        response = openai.chat.completions.create(
            model="gpt-5",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Extract and transcribe all text from this image. Return only the transcribed text in markdown format, preserving any formatting like headers, lists, etc."},
                        {"type": "image_url", "image_url": {"url": image_data_url}}
                    ]
                }
            ]
        )

        markdown_text = response.choices[0].message.content
        logger.info("GPT-5 OCR completed successfully")
        return markdown_text

    except Exception as e:
        logger.error(f"GPT-5 OCR error: {e}")
        return f"GPT-5 OCR error: {e}"


def process_model_ocr(image, model_name):
    """Process OCR for a specific model."""
    if model_name == "gemini":
        return gemini_ocr(image)
    elif model_name == "mistral":
        return mistral_ocr(image)
    elif model_name == "openai":
        return openai_ocr(image)
    elif model_name == "gpt5":
        return gpt5_ocr(image)
    else:
        return f"Unknown model: {model_name}"


def initialize_gemini():
    """Initialize the Gemini model with API key."""
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if gemini_api_key:
        genai.configure(api_key=gemini_api_key)
        logger.info("✅ GEMINI_API_KEY loaded successfully")
        return genai.GenerativeModel('gemini-2.0-flash-exp')
    else:
        logger.error("❌ GEMINI_API_KEY not found in environment variables")
        return None


def initialize_openai():
    """Initialize OpenAI with API key."""
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if openai_api_key:
        openai.api_key = openai_api_key
        logger.info("✅ OPENAI_API_KEY loaded successfully")
    else:
        logger.error("❌ OPENAI_API_KEY not found in environment variables")


def initialize_mistral():
    """Initialize Mistral with API key."""
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    if mistral_api_key:
        logger.info("✅ MISTRAL_API_KEY loaded successfully")
    else:
        logger.error("❌ MISTRAL_API_KEY not found in environment variables")