Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_cropper import st_cropper | |
from PIL import Image | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor | |
import torch | |
import re | |
import pytesseract | |
def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Transcribe the text in an image with a TrOCR-style checkpoint.

    Args:
        img: Image object exposing ``convert("RGB")`` (the app passes the
            cropped PIL image).
        model_name: Hugging Face model id used to load both the processor
            and the vision encoder-decoder model.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Both pieces are loaded from the same checkpoint on every call.
    ocr_processor = TrOCRProcessor.from_pretrained(model_name)
    ocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)

    rgb_image = img.convert("RGB")
    encoded = ocr_processor(rgb_image, return_tensors="pt")
    token_ids = ocr_model.generate(encoded.pixel_values, max_length=256)

    decoded = ocr_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Run a Donut checkpoint on an image and return its output as plain text.

    Args:
        img: Image object exposing ``convert("RGB")`` (the app passes the
            cropped PIL image).
        model_name: Hugging Face model id used to load both the processor
            and the vision encoder-decoder model.

    Returns:
        The first decoded sequence with the EOS/PAD tokens removed, every
        ``<...>`` tag stripped, and surrounding whitespace trimmed.
    """
    donut_processor = DonutProcessor.from_pretrained(model_name)
    donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    tokenizer = donut_processor.tokenizer

    # Run on the GPU when one is visible, otherwise stay on the CPU.
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"
    donut_model.to(target)

    # The decoder is primed with the task start token of the CORD-v2 prompt.
    prompt_ids = tokenizer(
        "<s_cord-v2>",
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    rgb_image = img.convert("RGB")
    pixels = donut_processor(rgb_image, return_tensors="pt").pixel_values

    generation_options = {
        "decoder_input_ids": prompt_ids.to(target),
        "max_length": donut_model.decoder.config.max_position_embeddings,
        "early_stopping": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
        "num_beams": 1,
        # Never emit the unknown token.
        "bad_words_ids": [[tokenizer.unk_token_id]],
        "return_dict_in_generate": True,
    }
    result = donut_model.generate(pixels.to(target), **generation_options)

    text = donut_processor.batch_decode(result.sequences)[0]
    for special_token in (tokenizer.eos_token, tokenizer.pad_token):
        text = text.replace(special_token, "")

    # Drop every remaining <tag> so only the recognised text is left.
    return re.sub(r"<.*?>", "", text).strip()
def predict_tesseract(img):
    """Extract text from an image with Tesseract.

    Args:
        img: Either an already-loaded PIL image (what the app passes after
            cropping) or anything ``Image.open`` accepts, such as a file
            path or a binary file object.

    Returns:
        The string returned by ``pytesseract.image_to_string``.
    """
    # The caller hands over a PIL image; calling Image.open on one fails,
    # so pass it straight through to Tesseract.
    if isinstance(img, Image.Image):
        return pytesseract.image_to_string(img)

    # Paths / file objects are still supported; the context manager releases
    # the handle PIL opened once OCR is done.
    with Image.open(img) as opened:
        return pytesseract.image_to_string(opened)
# --- Streamlit page -------------------------------------------------------
# set_page_config has to be the first Streamlit call of the script.
# NOTE(review): the former
#   st.set_option('deprecation.showfileUploaderEncoding', False)
# call was dropped: that config option is not available in Streamlit releases
# that provide st.columns, so setting it raises at start-up — confirm against
# the pinned Streamlit version.
st.set_page_config(
    page_title="Ex-stream-ly Cool App",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!"
    }
)

# Upload an image and set some options for demo purposes
st.header("Qalam: A Multilingual OCR System")
img_file = st.sidebar.file_uploader(label='Upload a file', type=['png', 'jpg'])
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)

# Only a free-form crop box is offered; None means "no fixed aspect ratio".
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=["Free"])
aspect_dict = {
    "Free": None
}
aspect_ratio = aspect_dict[aspect_choice]

Lng = st.sidebar.selectbox(label="Language", options=[
    "Arabic", "English", "French", "Korean", "Chinese"])

# Display name of the model used for each language.
Models = {
    "Arabic": "Qalam",
    "English": "Donut",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut"
}
# OCR function run for each language (mirrors the Models table above).
Predictors = {
    "Arabic": predict_arabic,
    "English": predict_english,
    "French": predict_tesseract,
    "Korean": predict_english,
    "Chinese": predict_english,
}
st.sidebar.write("# Model: ", Models[Lng])

# Nothing else is rendered until the user has uploaded an image.
if img_file:
    img = Image.open(img_file)
    if not realtime_update:
        st.write("Double click to save crop")

    col1, col2 = st.columns(2)
    with col1:
        st.header("Select Input Image")
        # Get a cropped image from the frontend
        cropped_img = st_cropper(
            img,
            realtime_update=realtime_update,
            box_color="#FF0000",
            aspect_ratio=aspect_ratio,
            should_resize_image=True,
        )
    with col2:
        # Show the current crop and, on demand, the text recognised in it.
        st.header("Output Image")
        st.image(cropped_img)
        button = st.button("Run OCR")
        if button:
            st.write(f"# {Lng} Text:")
            st.write(Predictors[Lng](cropped_img))