File size: 1,784 Bytes
4463d10 739fe18 84982b3 739fe18 3ec26e4 84982b3 739fe18 84982b3 52540c8 3ec26e4 739fe18 84982b3 739fe18 3ec26e4 84982b3 739fe18 84982b3 739fe18 4463d10 3ec26e4 4463d10 3ec26e4 4463d10 84982b3 739fe18 3ec26e4 52540c8 3ec26e4 84982b3 3ec26e4 739fe18 84982b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# app.py
import gradio as gr
import torch
from PIL import Image
from model import load_model
from utils import preprocess_image, decode_predictions
import os
# Artifacts this app expects to find relative to the working directory.
MODEL_PATH = "finetuned_recog_model.pth"
FONT_PATH = "NotoSansEthiopic-Regular.ttf"


def _require_file(label: str, path: str) -> None:
    """Raise FileNotFoundError naming *label* when *path* does not exist."""
    if os.path.exists(path):
        return
    raise FileNotFoundError(
        f"{label} file not found at {path}. Please provide the correct path."
    )


# Fail at import time, before any model loading, if an artifact is missing.
_require_file("Model", MODEL_PATH)
# NOTE(review): FONT_PATH is not read again in the part of the file visible
# here; presumably it is needed for visualization elsewhere -- confirm.
_require_file("Font", FONT_PATH)

# Use the GPU when torch reports one, otherwise run on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# Loaded once at import so every request reuses the same model object.
model = load_model(MODEL_PATH, device=device)
def recognize_text(image: Image.Image) -> str:
    """Recognize the text contained in a single image.

    Args:
        image: The PIL image to read, or ``None`` when nothing was uploaded.

    Returns:
        The literal message ``"No image provided."`` when *image* is ``None``;
        otherwise the first entry of what ``decode_predictions`` returns for
        the model output.
    """
    # Guard clause: the caller may pass None when no image was supplied.
    if image is None:
        return "No image provided."

    # Preprocess, add a leading batch axis, and move to the model's device.
    # Original author's note gives the result as [1, 3, 224, 224] -- TODO
    # confirm against preprocess_image.
    features = preprocess_image(image)
    batch = features.unsqueeze(0)
    batch = batch.to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # Original author's note: [H*W, 1, vocab_size] -- TODO confirm.
        scores = model(batch)

    # A single image was submitted, so only the first decoded entry is used.
    decoded = decode_predictions(scores)
    return decoded[0]
# Build the UI components first (image in, text out), then wire them to
# recognize_text in a single Gradio Interface.
_image_input = gr.Image(type="pil", label="Upload Image")
_text_output = gr.Textbox(label="Recognized Amharic Text")

iface = gr.Interface(
    fn=recognize_text,
    inputs=_image_input,
    outputs=_text_output,
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text, and the model will recognize and display the text.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|