import logging

import easyocr
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForTokenClassification

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Initializing EasyOCR...") |
|
|
|
reader = easyocr.Reader(['en'], gpu=False) |
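# easyocr.Reader downloads its detection and recognition weights on first use;
# gpu=False keeps inference on CPU so the app also runs on hosts without CUDA.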
logger.info("EasyOCR initialized.")

logger.info("Loading nutrition extraction model...") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor") |
|
model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor") |
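# from_pretrained already returns the model in eval mode, but being explicit is
# cheap: eval mode disables dropout so inference is deterministic.
model.eval()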
logger.info("Model loaded successfully.")


def ocr_extract(image: Image.Image):
    """
    Uses EasyOCR to extract text tokens and their bounding boxes from an image.
    Returns a list of tokens and corresponding boxes in [left, top, width, height] format.
    Bounding box coordinates are cast to int.
    """
    np_image = np.array(image)
    results = reader.readtext(np_image)

    tokens = []
    boxes = []
    for bbox, text, confidence in results:
        if text.strip():
            tokens.append(text)
            # EasyOCR returns each box as four corner points; collapse the
            # quadrilateral to an axis-aligned [left, top, width, height] rectangle.
            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]
            left = int(min(xs))
            top = int(min(ys))
            width = int(max(xs) - left)
            height = int(max(ys) - top)
            boxes.append([left, top, width, height])
    logger.info(f"OCR extracted {len(tokens)} tokens.")
    return tokens, boxes
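
# LayoutLM-family tokenizers expect word boxes as [x0, y0, x1, y1] normalized to a
# 0-1000 grid, and openfoodfacts/nutrition-extractor is a LayoutLM-style token
# classifier (its tokenizer takes a boxes= argument). The helper below is a minimal
# sketch of that conversion under that assumption; the name normalize_boxes is ours,
# not part of the model's API. predict() applies it before tokenization.
def normalize_boxes(boxes, img_width, img_height):
    normalized = []
    for left, top, width, height in boxes:
        normalized.append([
            max(0, min(1000, round(1000 * left / img_width))),
            max(0, min(1000, round(1000 * top / img_height))),
            max(0, min(1000, round(1000 * (left + width) / img_width))),
            max(0, min(1000, round(1000 * (top + height) / img_height))),
        ])
    return normalized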


def predict(image: Image.Image):
    """
    Runs OCR with EasyOCR to extract tokens and bounding boxes,
    then uses the nutrition extraction model to classify tokens and aggregate nutritional values.
    """
    tokens, boxes = ocr_extract(image)
    if len(tokens) == 0:
        logger.error("No text detected in the image.")
        return {"error": "No text detected in the image."}

    # Convert pixel-space boxes to the normalized format the tokenizer expects.
    boxes = normalize_boxes(boxes, image.width, image.height)
    encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)

    try:
        with torch.no_grad():  # no gradients needed at inference time
            outputs = model(**encoding)
    except Exception as e:
        logger.error(f"Error during model inference: {e}")
        return {"error": f"Model inference error: {e}"}

    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    # The tokenizer splits words into subword pieces, so predictions are per piece
    # (including special tokens), not per word. Use word_ids() to map each piece back
    # to its source word and keep the prediction of the first piece of each word.
    word_ids = encoding.word_ids(0)
    extracted_data = {}
    seen_words = set()
    for idx, word_id in enumerate(word_ids):
        if word_id is None or word_id in seen_words:
            continue
        seen_words.add(word_id)
        label = model.config.id2label.get(predictions[idx], "O").lower()
        if label == "o":
            continue

        # Keep only digits and the decimal point, e.g. "250kcal" -> "250".
        num_str = "".join(filter(lambda c: c.isdigit() or c == '.', tokens[word_id]))
        try:
            value = float(num_str)
            extracted_data[label] = extracted_data.get(label, 0) + value
        except ValueError:
            continue

    if not extracted_data:
        logger.warning("No nutritional information extracted.")
        return {"error": "No nutritional information extracted."}

    logger.info(f"Extracted data: {extracted_data}")
    return extracted_data
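
# Example return value (illustrative only; the real keys are the lowercased label
# names from model.config.id2label): {"energy": 250.0, "proteins": 12.0}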


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API with EasyOCR",
    description=(
        "Upload an image of a nutrition table to extract nutritional values. "
        "The pipeline uses EasyOCR to extract tokens and bounding boxes, then "
        "processes them with the openfoodfacts/nutrition-extractor model."
    ),
)
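# outputs="json" renders whatever dict predict returns (including the
# {"error": ...} fallbacks) in Gradio's JSON viewer component.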

if __name__ == "__main__":
    # share=True additionally requests a temporary public URL via Gradio's
    # tunnelling service; remove it to serve only on the local address.
    demo.launch(share=True)