# LPRforNajm / app.py
# MahmoudAbdelmaged's picture
# Update app.py
# eed7393 verified
import torch
import cv2
import easyocr
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
# Load the TorchScript export of the YOLOv9 detector. The relative path is
# resolved against the current working directory, so "best.torchscript" must
# exist there. map_location pins the weights to the CPU so that an export
# saved from a GPU session still loads on a CPU-only host, matching the
# CPU-only OCR reader configured below.
model = torch.jit.load("best.torchscript", map_location="cpu")
model.eval()  # Set to evaluation mode
# Initialize the EasyOCR reader for English and Arabic text (CPU only).
reader = easyocr.Reader(['en', 'ar'], gpu=False)
# Define a transformation pipeline for YOLOv9
def transform_image(image, size=(640, 640)):
    """Convert an image into a batched, channel-first float tensor.

    Implemented with NumPy and torch only: the original body referenced
    ``transforms`` (torchvision), which this file never imports.

    Args:
        image: A PIL image, or an ``H x W`` / ``H x W x C`` pixel array.
        size: Target ``(height, width)`` of the returned tensor.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, height, width)``. ``uint8``
        pixel data is scaled to the ``[0, 1]`` range; other dtypes are only
        cast to float.
    """
    # PIL images expose ``convert``; forcing RGB handles palette, grayscale
    # and RGBA uploads so the tensor always has three channels.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    # np.array copies, giving torch.from_numpy a writable buffer.
    pixels = np.array(image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, None]
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)
    pixels = pixels[:, :, :3]  # drop an alpha channel if one is present

    scale_to_unit = pixels.dtype == np.uint8
    tensor = torch.from_numpy(np.ascontiguousarray(pixels))
    # HWC -> NCHW, the layout interpolate() expects.
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).float().contiguous()
    if scale_to_unit:
        tensor = tensor / 255.0

    # Antialiased bilinear resampling, so downscaling large photos does not
    # alias fine detail such as plate characters.
    return torch.nn.functional.interpolate(
        tensor,
        size=tuple(size),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
# Function to process the uploaded image and extract text
def _detections_as_rows(results):
    """Flatten raw model output into a 2-D float array, one detection per row."""
    # The model may wrap its prediction tensor in a tuple/list; the first
    # element is taken, as the original code did for tuples.
    while isinstance(results, (tuple, list)):
        if not results:
            return np.empty((0, 4), dtype=np.float32)
        results = results[0]
    if isinstance(results, torch.Tensor):
        results = results.detach().cpu().numpy()

    rows = np.asarray(results, dtype=np.float32)
    if rows.size == 0:
        return np.empty((0, 4), dtype=np.float32)
    if rows.ndim == 0 or rows.shape[-1] < 4:
        raise ValueError(
            "Model output must provide at least 4 values per detection, "
            f"got shape {rows.shape}"
        )
    # Collapse any leading batch dimension so each row is a single detection.
    return rows.reshape(-1, rows.shape[-1])


def extract_text_from_image(image):
    """Detect regions with the YOLO model and OCR the text inside each one.

    Args:
        image: Image as a NumPy array (the Gradio input uses ``type="numpy"``),
            or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(annotated, text)``. ``annotated`` is an RGB array copy of
        the input with a red rectangle around every region that was OCR'd;
        ``text`` is the recognized strings joined by newlines. Returns
        ``(None, "")`` when ``image`` is ``None``.

    Raises:
        ValueError: If the model output has fewer than 4 values per detection.
    """
    if image is None:
        return None, ""

    img_pil = Image.fromarray(image).convert("RGB")
    img_tensor = transform_image(img_pil)

    with torch.no_grad():
        results = model(img_tensor)

    # NOTE(review): assumes each row starts with x1, y1, x2, y2 in model-input
    # pixel coordinates. Raw YOLO exports can instead be laid out as
    # (batch, 4 + classes, anchors) with xywh boxes and may still need
    # confidence filtering / NMS -- confirm against the actual export.
    boxes = _detections_as_rows(results)

    # The model saw a resized image, so map its coordinates back onto the
    # original resolution before cropping and drawing.
    input_h, input_w = img_tensor.shape[-2:]
    scale_x = img_pil.width / input_w
    scale_y = img_pil.height / input_h

    # Draw on a copy so rectangles never bleed into later OCR crops.
    annotated = img_pil.copy()
    draw = ImageDraw.Draw(annotated)
    extracted_text = []

    for box in boxes:
        x1, y1, x2, y2 = (float(value) for value in box[:4])
        left = int(round(min(max(x1 * scale_x, 0.0), img_pil.width)))
        top = int(round(min(max(y1 * scale_y, 0.0), img_pil.height)))
        right = int(round(min(max(x2 * scale_x, 0.0), img_pil.width)))
        bottom = int(round(min(max(y2 * scale_y, 0.0), img_pil.height)))
        # Empty or inverted boxes would make crop()/rectangle() raise.
        if right <= left or bottom <= top:
            continue

        cropped_img = img_pil.crop((left, top, right, bottom))
        ocr_result = reader.readtext(np.array(cropped_img))
        # Each detection is indexed at [1] for its recognized text.
        extracted_text.extend(detection[1] for detection in ocr_result)

        draw.rectangle([left, top, right, bottom], outline="red", width=3)

    # The image is already RGB end to end, so no BGR/RGB channel swap is
    # applied before handing it back to Gradio.
    return np.array(annotated), "\n".join(extracted_text)
# Build the Gradio UI: one uploaded image in, annotated image and text out.
_upload_input = gr.Image(type="numpy", label="Upload Image")
_result_outputs = [
    gr.Image(type="numpy", label="Processed Image"),
    gr.Text(label="Extracted Text (Line by Line)"),
]

interface = gr.Interface(
    title="Object and Text Extractor",
    description="Upload an image to detect objects using YOLOv9 and extract text using EasyOCR.",
    fn=extract_text_from_image,
    inputs=_upload_input,
    outputs=_result_outputs,
)

# Start serving the app.
interface.launch()