Update app.py
f630d31 verified
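"""Gradio app that extracts nutritional values from an image of a nutrition table.

The pipeline runs EasyOCR to get text tokens and bounding boxes, then classifies
the tokens with the openfoodfacts/nutrition-extractor model and aggregates the
numeric values per label.
"""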
import gradio as gr
import easyocr
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import logging
# Set up logging for debugging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Initializing EasyOCR...")
# Initialize the EasyOCR reader for English.
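# The first initialization downloads EasyOCR's detection and recognition weights
# if they are not already cached locally.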
reader = easyocr.Reader(['en'], gpu=False)
logger.info("EasyOCR initialized.")
logger.info("Loading nutrition extraction model...")
# Load the tokenizer and token-classification model directly.
# Inference runs on the CPU by default; no GPU is requested.
tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor")
model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor")
logger.info("Model loaded successfully.")
def ocr_extract(image: Image.Image):
"""
Uses EasyOCR to extract text tokens and their bounding boxes from an image.
Returns a list of tokens and corresponding boxes in [left, top, width, height] format.
Bounding box coordinates are cast to int.
"""
# Convert PIL image to numpy array.
np_image = np.array(image)
results = reader.readtext(np_image)
tokens = []
boxes = []
for bbox, text, confidence in results:
if text.strip():
tokens.append(text)
# Convert the bounding box (list of 4 points) to [left, top, width, height].
xs = [point[0] for point in bbox]
ys = [point[1] for point in bbox]
left = int(min(xs))
top = int(min(ys))
width = int(max(xs) - left)
height = int(max(ys) - top)
boxes.append([left, top, width, height])
logger.info(f"OCR extracted {len(tokens)} tokens.")
return tokens, boxes
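
# LayoutLM-family tokenizers generally expect word boxes as [x0, y0, x1, y1]
# normalized to the 0-1000 range. Whether the nutrition-extractor checkpoint
# requires that convention is an assumption, not confirmed here; if it does,
# the pixel boxes above could be rescaled with a helper along these lines:
def normalize_boxes(boxes, image_width, image_height):
    """Sketch: convert [left, top, width, height] pixel boxes to 0-1000 [x0, y0, x1, y1]."""
    normalized = []
    for left, top, width, height in boxes:
        normalized.append([
            int(1000 * left / image_width),
            int(1000 * top / image_height),
            int(1000 * (left + width) / image_width),
            int(1000 * (top + height) / image_height),
        ])
    return normalized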
def predict(image: Image.Image):
"""
Runs OCR with EasyOCR to extract tokens and bounding boxes,
then uses the nutrition extraction model to classify tokens and aggregate nutritional values.
"""
tokens, boxes = ocr_extract(image)
if len(tokens) == 0:
logger.error("No text detected in the image.")
return {"error": "No text detected in the image."}
# Prepare inputs: pass the tokens and boxes to the tokenizer.
encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)
    try:
        # Run inference without tracking gradients.
        with torch.no_grad():
            outputs = model(**encoding)
    except Exception as e:
        logger.error(f"Error during model inference: {e}")
        return {"error": f"Model inference error: {e}"}
    # Get the predicted label id for each sub-token.
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    # Map sub-token predictions back to the original OCR words. This assumes the
    # checkpoint ships a fast tokenizer, whose word_ids() returns None for special
    # tokens; each word takes the label predicted for its first sub-token.
    word_ids = encoding.word_ids(0)
    extracted_data = {}
    seen_words = set()
    for idx, word_id in enumerate(word_ids):
        if word_id is None or word_id in seen_words:
            continue
        seen_words.add(word_id)
        label = model.config.id2label.get(predictions[idx], "O").lower()
        if label == "o":
            continue
        # Extract the numeric value from the word and aggregate it per label.
        num_str = "".join(c for c in tokens[word_id] if c.isdigit() or c == '.')
        try:
            value = float(num_str)
        except ValueError:
            continue
        extracted_data[label] = extracted_data.get(label, 0) + value
if not extracted_data:
logger.warning("No nutritional information extracted.")
return {"error": "No nutritional information extracted."}
logger.info(f"Extracted data: {extracted_data}")
return extracted_data
# Create a Gradio interface that exposes the API.
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs="json",
title="Nutrition Extractor API with EasyOCR",
description="Upload an image of a nutrition table to extract nutritional values. The pipeline uses EasyOCR to extract tokens and bounding boxes, then processes them with the openfoodfacts/nutrition-extractor model."
)
if __name__ == "__main__":
demo.launch(share=True)
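
# Example client call (a sketch; the Space id below is hypothetical, and the exact
# gradio_client API for file inputs depends on the installed version):
#
#   from gradio_client import Client, handle_file
#   client = Client("your-username/nutrition-extractor-space")  # hypothetical Space id
#   result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
#   print(result)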