import streamlit as st
from streamlit_cropper import st_cropper
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor
import torch
import re
import pytesseract


def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Run Arabic OCR on a PIL image using a TrOCR-style model.

    Args:
        img: PIL.Image containing the text to recognize.
        model_name: Hugging Face checkpoint id of the
            vision-encoder-decoder model (default: UBC-NLP/Qalam).

    Returns:
        The decoded text as a single string.
    """
    # NOTE(review): processor and model are rebuilt on every call; consider
    # caching them (e.g. st.cache_resource) if this becomes a bottleneck.
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)

    # TrOCR expects RGB input.
    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values

    # Inference only — disable autograd to save time and memory.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_length=256)

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Run OCR on a PIL image using a Donut vision-encoder-decoder model.

    Args:
        img: PIL.Image containing the text to recognize.
        model_name: Hugging Face checkpoint id of the Donut model
            (default: naver-clova-ix/donut-base-finetuned-cord-v2).

    Returns:
        The decoded text with special tokens and task tags stripped.
    """
    # NOTE(review): processor and model are rebuilt on every call; consider
    # caching them (e.g. st.cache_resource) if this becomes a bottleneck.
    processor = DonutProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Donut is steered by a task prompt fed as the decoder's first tokens.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # Donut expects RGB input.
    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values

    # Inference only — disable autograd to save time and memory.
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

    # Strip special tokens, then remove any remaining <...> task/field tags.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence).strip()
    return sequence


def predict_tesseract(img):
    """Run Tesseract OCR on an image.

    Args:
        img: Either an already-loaded PIL.Image (what the cropper in this
            app passes), or a path/file-like object openable by PIL.

    Returns:
        The recognized text string.
    """
    # Bug fix: callers pass a PIL.Image from st_cropper, but Image.open()
    # only accepts a path or file-like object and would raise on a PIL
    # Image. Accept both forms, staying backward-compatible with paths.
    image = img if isinstance(img, Image.Image) else Image.open(img)
    return pytesseract.image_to_string(image)


# st.set_page_config must be the FIRST Streamlit command executed in the
# script, so it runs before any other st.* call.
st.set_page_config(
    page_title="Ex-stream-ly Cool App",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!"
    }
)

# Legacy option: 'deprecation.showfileUploaderEncoding' was removed in
# modern Streamlit and setting it unconditionally crashes the app. Guard it
# so the app still runs on both old and new Streamlit versions.
try:
    st.set_option('deprecation.showfileUploaderEncoding', False)
except Exception:
    # Option no longer exists — safe to ignore on current Streamlit.
    pass

# --- Sidebar: file upload and cropper/OCR options ---------------------------
st.header("Qalam: A Multilingual OCR System")

img_file = st.sidebar.file_uploader(label='Upload a file', type=['png', 'jpg'])
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)

# Only a free-form crop is offered; "Free" maps to no aspect-ratio constraint.
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=["Free"])
aspect_dict = {"Free": None}
aspect_ratio = aspect_dict[aspect_choice]

# Language selection drives which OCR backend handles the cropped image.
Lng = st.sidebar.selectbox(label="Language", options=[
    "Arabic", "English", "French", "Korean", "Chinese"])

# Display name of the model backing each language.
Models = {
    "Arabic": "Qalam",
    "English": "Donut",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut",
}

st.sidebar.write("# Model: ", Models[Lng])

if img_file:
    img = Image.open(img_file)
    if not realtime_update:
        st.write("Double click to save crop")

    col1, col2 = st.columns(2)

    with col1:
        st.header("Select Input Image")
        # Interactive cropper widget; returns the selected region as a
        # PIL image.
        cropped_img = st_cropper(
            img,
            realtime_update=realtime_update,
            box_color="#FF0000",
            aspect_ratio=aspect_ratio,
            should_resize_image=True,
        )

    with col2:
        st.header("Output Image")
        st.image(cropped_img)
        if st.button("Run OCR"):
            # Dispatch table mapping each language to its OCR backend
            # (replaces the original if/elif chain; same pairings).
            ocr_backends = {
                "Arabic": predict_arabic,
                "English": predict_english,
                "French": predict_tesseract,
                "Korean": predict_english,
                "Chinese": predict_english,
            }
            st.write(f"# {Lng} Text:")
            st.write(ocr_backends[Lng](cropped_img))