# QalamV0.2 / app.py
# Uploaded by gagan3012 (commit 9fefb79, "Upload 2 files").
import streamlit as st
from streamlit_cropper import st_cropper
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor
import torch
import re
import pytesseract
# Streamlit re-executes this script on every widget interaction, so ordinary
# module-level caches (e.g. functools.lru_cache) are rebuilt on each rerun.
# st.cache_resource keeps loaded models alive across reruns; on Streamlit
# versions that lack it, fall back to an identity decorator (load per call).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_trocr(model_name):
    """Load the TrOCR processor and model for *model_name* and pick a device.

    Returns:
        (processor, model, device) where device is "cuda" when available,
        otherwise "cpu"; the model has already been moved to that device.
    """
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return processor, model, device


def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Transcribe the text in *img* with a TrOCR-style checkpoint (Qalam by default).

    Args:
        img: PIL image to read; it is converted to RGB before preprocessing.
        model_name: Hugging Face model id or local path of the checkpoint.

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    processor, model, device = _load_trocr(model_name)
    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device), max_length=256)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Streamlit re-executes this script on every widget interaction, so ordinary
# module-level caches are rebuilt on each rerun. st.cache_resource keeps the
# loaded model alive across reruns; without it, fall back to loading per call.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_donut(model_name):
    """Load the Donut processor and model for *model_name* and pick a device.

    Returns:
        (processor, model, device) where device is "cuda" when available,
        otherwise "cpu"; the model has already been moved to that device.
    """
    processor = DonutProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return processor, model, device


def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2",
                    task_prompt="<s_cord-v2>"):
    """Read the text in *img* with a Donut checkpoint and return it as plain text.

    Args:
        img: PIL image to read; it is converted to RGB before preprocessing.
        model_name: Hugging Face model id or local path of a Donut checkpoint.
        task_prompt: decoder start prompt; it must match the checkpoint's task.
            The default matches the default CORD-v2 model.

    Returns:
        The decoded sequence with EOS/PAD tokens and every ``<...>`` tag
        removed, stripped of surrounding whitespace.
    """
    processor, model, device = _load_donut(model_name)
    tokenizer = processor.tokenizer

    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        # Greedy decoding. early_stopping is a beam-search option, so it is
        # not passed here.
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # Removing every tag (including the echoed task prompt) flattens Donut's
    # tagged output into plain text for display.
    return re.sub(r"<.*?>", "", sequence).strip()
def predict_tesseract(img, lang=None):
    """Run Tesseract OCR on *img* and return the recognised text.

    Args:
        img: either an already-loaded PIL image (what st_cropper returns) or
            anything ``Image.open`` accepts (path or binary file object).
            Calling ``Image.open`` on a PIL image raises, which previously
            crashed every French OCR request.
        lang: optional Tesseract language code(s), e.g. ``"fra"``. None keeps
            pytesseract's default, matching the previous behaviour.

    Returns:
        The text produced by ``pytesseract.image_to_string``.
    """
    image = img if isinstance(img, Image.Image) else Image.open(img)
    return pytesseract.image_to_string(image, lang=lang)
# deprecation.showfileUploaderEncoding only exists in old Streamlit releases.
# Newer ones reject unknown options with StreamlitAPIException, which would
# otherwise stop the app before anything renders, so the call is best-effort.
from streamlit.errors import StreamlitAPIException

try:
    st.set_option('deprecation.showfileUploaderEncoding', False)
except StreamlitAPIException:
    pass

st.set_page_config(
    page_title="Qalam: A Multilingual OCR System",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    # The template's "Get Help" / "Report a bug" links pointed at a placeholder
    # domain. They are left unset so Streamlit's own menu entries are used.
    menu_items={
        'About': "# Qalam: A Multilingual OCR System"
    }
)
# Page heading, then the sidebar controls that drive the cropper and OCR below.
st.header("Qalam: A Multilingual OCR System")
img_file = st.sidebar.file_uploader(label='Upload a file', type=['png', 'jpg'])
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)

# Aspect-ratio presets offered to the user; the chosen value is passed to
# st_cropper as-is. Only the free-form preset (None) exists today.
aspect_dict = {
    "Free": None,
}
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=list(aspect_dict))
aspect_ratio = aspect_dict[aspect_choice]

# OCR backend name displayed for each selectable language. The dict order is
# also the order of the language dropdown.
Models = {
    "Arabic": "Qalam",
    "English": "Donut",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut",
}
Lng = st.sidebar.selectbox(label="Language", options=list(Models))
st.sidebar.write("# Model: ", Models[Lng])
# Language -> (result heading, OCR function). Korean and Chinese reuse the
# Donut pipeline behind predict_english.
_OCR_ROUTES = {
    "Arabic": ("# Arabic Text:", predict_arabic),
    "English": ("# English Text:", predict_english),
    "French": ("# French Text:", predict_tesseract),
    "Korean": ("# Korean Text:", predict_english),
    "Chinese": ("# Chinese Text:", predict_english),
}

if img_file:
    img = Image.open(img_file)
    if not realtime_update:
        st.write("Double click to save crop")

    col1, col2 = st.columns(2)

    with col1:
        st.header("Select Input Image")
        # The cropper widget returns the currently selected region of img.
        cropped_img = st_cropper(
            img,
            realtime_update=realtime_update,
            box_color="#FF0000",
            aspect_ratio=aspect_ratio,
            should_resize_image=True,
        )

    with col2:
        st.header("Output Image")
        st.image(cropped_img)
        button = st.button("Run OCR")
        if button and Lng in _OCR_ROUTES:
            heading, run_ocr = _OCR_ROUTES[Lng]
            # The heading is written before OCR starts, as in the original flow.
            st.write(heading)
            st.write(run_ocr(cropped_img))