from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import io
import os
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import fitz  # PyMuPDF

# Initialize Flask
app = Flask(__name__)
CORS(app)

# Load Donut model and processor once at import time so every request
# reuses the same instances.
device = "cpu"
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base").to(device)
model.eval()


def convert_pdf_to_image(file_stream):
    """Render the first page of a PDF as an RGB PIL image.

    Args:
        file_stream: Binary file-like object whose ``read()`` returns the
            raw PDF bytes.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, rendered at 150 DPI.

    Raises:
        ValueError: If the document contains no pages.
    """
    # The context manager closes the document even if rendering fails.
    with fitz.open(stream=file_stream.read(), filetype="pdf") as doc:
        if doc.page_count == 0:
            raise ValueError("PDF has no pages")
        page = doc.load_page(0)
        pix = page.get_pixmap(dpi=150)
        # The image is built from the pixmap's bytes before the document
        # is closed, so it can be returned safely.
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def _load_upload_as_image(upload):
    """Turn an uploaded file (PDF or image) into an RGB PIL image."""
    # The filename may be missing; treat that as "not a PDF".
    filename = (upload.filename or "").lower()
    if filename.endswith(".pdf"):
        return convert_pdf_to_image(upload)
    return Image.open(io.BytesIO(upload.read())).convert("RGB")


@app.route("/ocr", methods=["POST"])
def ocr():
    """Run the Donut model on the uploaded ``file`` and return its text.

    Expects a multipart upload under the ``file`` field. Files whose name
    ends in ``.pdf`` are rendered from their first page; anything else is
    opened as an image.

    Returns:
        JSON ``{"text": ...}`` on success, or JSON ``{"error": ...}`` with
        status 400 when no file was sent or the file cannot be decoded.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]

    # Convert input to PIL image. Undecodable uploads are a client error,
    # not a server fault, so report them as 400 instead of crashing.
    try:
        image = _load_upload_as_image(upload)
    except (OSError, ValueError, RuntimeError, Image.DecompressionBombError):
        return jsonify({"error": "Could not read file as an image or PDF"}), 400

    # Preprocess image
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Run model
    with torch.no_grad():
        output = model.generate(
            pixel_values,
            max_length=512,
            return_dict_in_generate=True,
        )

    # Decode output (special tokens are stripped from the returned text)
    parsed_text = processor.tokenizer.decode(
        output.sequences[0], skip_special_tokens=True
    )

    return jsonify({"text": parsed_text})


@app.route("/", methods=["GET"])
def index():
    """Return a short plain-text banner identifying the service."""
    return "Smart OCR Flask API (Donut-based)"


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)