import io
import os
import boto3
import traceback
import re
import logging

import gradio as gr
from PIL import Image, ImageDraw

from docquery.document import load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import pipeline

# avoid SSL certificate verification errors (e.g. when downloading model weights)
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Init models

layoutlm_pipeline = pipeline(
    "document-question-answering",
    model="impira/layoutlm-document-qa",
)
lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
lilt_model = AutoModelForQuestionAnswering.from_pretrained(
    "nielsr/lilt-xlm-roberta-base"
)

donut_processor = DonutProcessor.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-docvqa"
)
donut_model = VisionEncoderDecoderModel.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-docvqa"
)

TEXTRACT = "Textract Query"
LAYOUTLM = "LayoutLM"
DONUT = "Donut"
LILT = "LiLT"


def image_to_byte_array(image: Image.Image) -> bytes:
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def run_textract(question, document):
    logger.info(f"Running Textract model.")
    image_as_byte_base64 = image_to_byte_array(image=document.b)
    response = boto3.client("textract").analyze_document(
        Document={
            "Bytes": image_as_byte_base64,
        },
        FeatureTypes=[
            "QUERIES",
        ],
        QueriesConfig={
            "Queries": [
                {
                    "Text": question,
                    "Pages": [
                        "*",
                    ],
                },
            ]
        },
    )
    logger.info(f"Output of Textract model {response}.")
    for element in response["Blocks"]:
        if element["BlockType"] == "QUERY_RESULT":
            return {
                "score": element["Confidence"],
                "answer": element["Text"],
                # "word_ids": element
            }
    raise ValueError("No QUERY_RESULT found in the response from Textract.")


def run_layoutlm(question, document):
    logger.info(f"Running layoutlm model.")
    result = layoutlm_pipeline(document.context["image"][0][0], question)[0]
    logger.info(f"Output of layoutlm model {result}.")
    # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
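    # start/end are indices into the page's OCR word list; process_question uses them to pick word boxes for the highlight.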
    return {
        "score": result["score"],
        "answer": result["answer"],
        "word_ids": [result["start"], result["end"]],
        "page": 0,
    }


def run_lilt(question, document):
    logger.info(f"Running lilt model.")
    # use this model + tokenizer
    processed_document = document.context["image"][0][1]
    words = [x[0] for x in processed_document]
    boxes = [x[1] for x in processed_document]

    encoding = lilt_tokenizer(
        text=question,
        text_pair=words,
        boxes=boxes,
        add_special_tokens=True,
        return_tensors="pt",
    )
    outputs = lilt_model(**encoding)
    logger.info(f"Output for lilt model {outputs}.")

    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()

    predict_answer_tokens = encoding.input_ids[
        0, answer_start_index : answer_end_index + 1
    ]
    predict_answer = lilt_tokenizer.decode(
        predict_answer_tokens, skip_special_tokens=True
    )
    return {
        "score": "n/a",
        "answer": predict_answer,
        # "word_ids": element
    }


def run_donut(question, document):
    logger.info(f"Running donut model.")
    # prepare encoder inputs
    pixel_values = donut_processor(
        document.context["image"][0][0], return_tensors="pt"
    ).pixel_values

    # prepare decoder inputs
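    # The DocVQA-finetuned checkpoint expects the question wrapped in the task tags below as its decoder prompt.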
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = donut_processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer
    outputs = donut_model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=donut_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=donut_processor.tokenizer.pad_token_id,
        eos_token_id=donut_processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    logger.info(f"Output for donut {outputs}")
    sequence = donut_processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(
        donut_processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(
        r"<.*?>", "", sequence, count=1
    ).strip()  # remove first task start token

    result = donut_processor.token2json(sequence)
    return {
        "score": "n/a",
        "answer": result["answer"],
        # "word_ids": element
    }


def process_path(path):
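    # Return values map to the upload outputs: document state, preview gallery, clear button, JSON output, answer textbox.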
    error = None
    if path:
        try:
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
            )
        except Exception as e:
            traceback.print_exc()
            error = str(e)
    if error is not None:
        logger.error(f"Failed to load document: {error}")
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
    )


def process_upload(file):
    if file:
        return process_path(file.name)
    else:
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
        )


def lift_word_boxes(document, page):
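    # Each page's context entry is (page image, [(word, box), ...]); return the word/box list.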
    return document.context["image"][page][1]


def expand_bbox(word_boxes):
    if len(word_boxes) == 0:
        return None

    # zip(*) transposes the (x1, y1, x2, y2) boxes into per-coordinate tuples;
    # the extremes give the box that encloses every word.
    xs1, ys1, xs2, ys2 = zip(*[x[1] for x in word_boxes])
    return [min(xs1), min(ys1), max(xs2), max(ys2)]


# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height, padding=0.005):
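    # padding is a fraction of the page size, applied in normalized space before scaling to pixels.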
    min_x, min_y, max_x, max_y = [c / 1000 for c in box]
    if padding != 0:
        min_x = max(0, min_x - padding)
        min_y = max(0, min_y - padding)
        max_x = min(max_x + padding, 1)
        max_y = min(max_y + padding, 1)
    return [min_x * width, min_y * height, max_x * width, max_y * height]


MODELS = {
    LAYOUTLM: run_layoutlm,
    DONUT: run_donut,
    # LILT: run_lilt,
    # TEXTRACT: run_textract,
}


def process_question(question, document, model=list(MODELS.keys())[0]):
    if not question or document is None:
        return None, None, None
    logger.info(f"Running for model {model}")
    prediction = MODELS[model](question=question, document=document)
    logger.info(f"Got prediction {prediction}")
    pages = [x.copy().convert("RGB") for x in document.preview]
    text_value = prediction["answer"]
    if "word_ids" in prediction:
        logger.info(f"Setting bounding boxes.")
        image = pages[prediction["page"]]
        draw = ImageDraw.Draw(image, "RGBA")
        word_boxes = lift_word_boxes(document, prediction["page"])
        x1, y1, x2, y2 = normalize_bbox(
            expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
            image.width,
            image.height,
        )
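        # Overlay a translucent green box on the predicted answer region.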
        draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))

    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=prediction),
        gr.update(
            visible=True,
            value=text_value,
        ),
    )


def load_example_document(img, question, model):
    if img is not None:
        document = ImageDocument(Image.fromarray(img), get_ocr_reader())
        preview, answer, answer_text = process_question(question, document, model)
        return document, question, preview, gr.update(visible=True), answer, answer_text
    else:
        return None, None, None, gr.update(visible=False), None, None


CSS = """
#question input {
    font-size: 16px;
}
#url-textbox {
    padding: 0 !important;
}
#short-upload-box .w-full {
    min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
 * the table
 */
/*
.gr-samples-table tr {
    display: inline;
}
.gr-samples-table .p-2 {
    width: 100px;
}
*/
#select-a-file {
    width: 100%;
}
#file-clear {
    padding-top: 2px !important;
    padding-bottom: 2px !important;
    padding-left: 8px !important;
    padding-right: 8px !important;
    margin-top: 10px;
}
.gradio-container .gr-button-primary {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700;
}
.gradio-container.dark button#submit-button {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700;
}
table.gr-samples-table tr td {
    border: none;
    outline: none;
}
table.gr-samples-table tr td:first-of-type {
    width: 0%;
}
div#short-upload-box div.absolute {
    display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
    gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
    gap: 0px;
}
gradio-app h2, .gradio-app h2 {
    padding-top: 10px;
}
#answer {
    overflow-y: scroll;
    color: white;
    background: #666;
    border-color: #666;
    font-size: 20px;
    font-weight: bold;
}
#answer span {
    color: white;
}
#answer textarea {
    color:white;
    background: #777;
    border-color: #777;
    font-size: 18px;
}
#url-error input {
    color: red;
}
"""

examples = [
    [
        "scenario-1.png",
        "What is the final consignee?",
    ],
    [
        "scenario-1.png",
        "What are the payment terms?",
    ],
    [
        "scenario-2.png",
        "What is the actual manufacturer?",
    ],
    [
        "scenario-3.png",
        'What is the "ship to" destination?',
    ],
    [
        "scenario-4.png",
        "What is the color?",
    ],
    [
        "scenario-5.png",
        'What is the "said to contain"?',
    ],
    [
        "scenario-5.png",
        'What is the "Net Weight"?',
    ],
    [
        "scenario-5.png",
        'What is the "Freight Collect"?',
    ],
    [
        "bill_of_lading_1.png",
        "What is the shipper?",
    ],
    [
        "japanese-invoice.png",
        "What is the total amount?",
    ],
    [
        "example-10.jpeg",
        "What is the mineral water price amount?",
    ],
]

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Document Question Answer Comparator")
    gr.Markdown("""
This space compares some of the latest models that can be used commercially.
- [LayoutLM](https://huggingface.co/impira/layoutlm-document-qa) uses text/layout and images. Uses tesseract for OCR. 
- [Donut](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) OCR free document understanding. Uses vision encoder for OCR and a text decoder for providing the answer  
""")

    document = gr.Variable()
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_question],
            )

        with gr.Column() as col:
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
                max_lines=1,
            )
            model = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )

            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output_text = gr.Textbox(
                    label="Top Answer", visible=False, elem_id="answer"
                )
                output = gr.JSON(label="Output", visible=False)

    for cb in [img_clear_button, clear_button]:
        cb.click(
            # One reset value per output: image, document, output, output_text,
            # img_clear_button, example_image, upload, question.
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                output_text,
                img_clear_button,
                example_image,
                upload,
                question,
            ],
        )

    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, output_text],
    )

    question.submit(
        fn=process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    submit_button.click(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    model.change(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, model],
        outputs=[document, question, image, img_clear_button, output, output_text],
    )

if __name__ == "__main__":
    demo.launch(enable_queue=False)