from .utils import load_model, normalize_box, compare_boxes, adjacent
from .annotate_image import get_flattened_output, annotate_image
from PIL import Image
import logging
import torch
import json
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = logging.getLogger(__name__)


class ModelHandler(object):
    """
    A base model handler implementation for LayoutLM-based document field extraction.
    """

    def __init__(self):
        self.initialized = False
        self._raw_input_data = None
        self._processed_data = None
        self._images_size = None

    def initialize(self, context, preprocessor, name):
        """
        Initialize the handler. Called once at model-loading time.
        :param context: the loaded transformer model to run inference with
        :param preprocessor: the processor used to encode images, words and boxes
        :param name: document type to extract ("cheque", "aadhar", "pan" or "gst")
        :return:
        """
        logger.info("Loading transformer model")

        self.name = name
        self.model = context
        self.preprocessor = preprocessor
        self.initialized = True

    def preprocess(self, batch):
        """
        Transform raw input into model input data.
        :param batch: dict with 'image_path', 'words' and 'bboxes' lists, one entry per page
        :return: encoded model inputs, moved to the active device
        """
        inference_dict = batch
        logger.debug("inference_dict: %s", inference_dict)
        self._raw_input_data = inference_dict
        # The processor is injected via initialize(); it must match the model
        # architecture (LayoutLMv2 or LayoutLMv3).
        processor = self.preprocessor

        images = [Image.open(path).convert("RGB")
                  for path in inference_dict['image_path']]
        self._images_size = [img.size for img in images]
        words = inference_dict['words']
        boxes = [[normalize_box(box, images[i].size[0], images[i].size[1])
                  for box in doc] for i, doc in enumerate(inference_dict['bboxes'])]
        encoded_inputs = processor(
            images, words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)
        self._processed_data = encoded_inputs
        encoded_inputs = {key: val.to(device) for key, val in encoded_inputs.items()}
        print("encoded_inputs",encoded_inputs)

        return encoded_inputs

    def load(self, model_dir):
        """The load handler is responsible for loading the hunggingface transformer model.
        Returns:
            hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
        """
        # TODO model dir should be microsoft/layoutlmv2-base-uncased
        model = load_model(model_dir)
        return model
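
    # `load_model` lives in .utils and is not shown here. A minimal sketch of
    # what it plausibly does (an assumption, not the actual implementation):
    #
    #   from transformers import AutoModelForTokenClassification
    #
    #   def load_model(model_dir):
    #       model = AutoModelForTokenClassification.from_pretrained(model_dir)
    #       return model.to(device).eval()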

    def inference(self, model_input):
        """
        Internal inference method.
        :param model_input: encoded model input (output of preprocess)
        :return: list containing one dict per page, mapping 'output_<i>' to
                 the list of predicted label ids, one per encoded token
        """
        with torch.no_grad():
            inference_outputs = self.model(**model_input)
            predictions = inference_outputs.logits.argmax(-1).tolist()
        logger.debug("predictions: %s", predictions)
        results = [{f'output_{i}': page_predictions}
                   for i, page_predictions in enumerate(predictions)]
        return [results]
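
    # Example of the structure returned by inference() for a two-page batch
    # (label ids illustrative):
    #   [[{'output_0': [0, 3, 3, 0, ...]}, {'output_1': [0, 5, 0, ...]}]]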

    def postprocess(self, inference_output):
        """
        Map token-level predictions back onto the input words and merge
        adjacent labelled words into spans.
        :param inference_output: output of inference()
        :return: single-element list with the JSON-serialized spans per page
        """
        logger.debug("words: %s", self._raw_input_data['words'])
        logger.debug("inference_output: %s", inference_output)
        docs = []
        k = 0
        for page, doc_words in enumerate(self._raw_input_data['words']):
            doc_list = []
            width, height = self._images_size[page]
            for i, doc_word in enumerate(doc_words):
                word_tagging = None
                word_labels = []
                word = dict()
                word['id'] = k
                k += 1
                word['text'] = doc_word
                word['pageNum'] = page + 1
                word['box'] = self._raw_input_data['bboxes'][page][i]
                _normalized_box = normalize_box(
                    self._raw_input_data['bboxes'][page][i], width, height)
                for j, box in enumerate(self._processed_data['bbox'].tolist()[page]):
                    if compare_boxes(box, _normalized_box):
                        label = self.model.config.id2label[
                            inference_output[0][page][f'output_{page}'][j]]
                        # Strip the BIO prefix (e.g. 'B-AN' -> 'AN'); 'O' means unlabelled.
                        word_labels.append(label[2:] if label != 'O' else 'other')
                if word_labels:
                    word_tagging = word_labels[0] if word_labels[0] != 'other' else word_labels[-1]
                else:
                    word_tagging = 'other'
                word['label'] = word_tagging
                word['pageSize'] = {'width': width, 'height': height}
                if word['label'] != 'other':
                    doc_list.append(word)
            spans = []

            def adjacents(entity):
                return [adj for adj in doc_list if adjacent(entity, adj)]

            # Isolated labelled words (no adjacent labelled neighbour) become
            # single-word spans.
            remaining = doc_list[:]
            for entity in doc_list:
                if not adjacents(entity):
                    spans.append([entity])
                    remaining.remove(entity)

            # Greedily merge runs of mutually adjacent words into spans.
            while remaining:
                span = [remaining.pop(0)]
                while remaining and adjacent(span[-1], remaining[0]):
                    span.append(remaining.pop(0))
                spans.append(span)
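            # Worked example (assuming `adjacent` is true for words that follow
            # each other in reading order): doc_list = [w1:"12", w2:"3456",
            # w3:"IFSC0001"] with only w1/w2 adjacent yields
            # spans = [[w1, w2], [w3]], serialized below as "12 3456" and "IFSC0001".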

            output_spans = []
            for span in spans:
                # A single dict shape covers both single- and multi-word spans.
                output_span = {"text": ' '.join(entity['text'] for entity in span),
                               "label": span[0]['label'],
                               "words": [{
                                   'id': entity['id'],
                                   'box': entity['box'],
                                   'text': entity['text']
                               } for entity in span]}
                output_spans.append(output_span)
            docs.append({'output': output_spans})
        return [json.dumps(docs, ensure_ascii=False)]
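
    # Shape of the serialized postprocess output (values illustrative):
    #   [{"output": [{"text": "12 3456", "label": "AN",
    #                 "words": [{"id": 0, "box": [x0, y0, x1, y1], "text": "12"},
    #                           {"id": 1, "box": [x0, y0, x1, y1], "text": "3456"}]}]}]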

    def handle(self, data, context):
        """
        Call preprocess, inference and post-process functions
        :param data: input data
        :param context: mms context
        """
        # print("\nmodel_input\n",data)
        print("context",context)
        
        model_input = self.preprocess(data)
        print("this is model input",model_input)
        model_out = self.inference(model_input)
        print("\nmodel_output\n",model_out)
        inference_out = self.postprocess(model_out)[0]

        print("\nprocessed output\n",inference_out)

        # with open('LayoutlMV3InferenceOutput.json', 'w') as inf_out:
        #     inf_out.write(inference_out)
        inference_out_list = json.loads(inference_out)
        flattened_output_list = get_flattened_output(inference_out_list)
        print("flattened_output_list",flattened_output_list)
        
        
        
        if self.name == "cheque":
            acc_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
            IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            
            print("entered cheque\n\n",flattened_output_list,"\n\n")
            result = {"attachment_num":acc_num,
                      "attachment_ifsc":IFSC,
                      "attachment_status":200}
        if self.name  == "aadhar":
            # aadhar_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
            output_ls = []

            for item in flattened_output_list[0]['output']:
                if item['label'] == 'AN' and item['text'] not in output_ls:
                    print("outputls",output_ls)
                    print("item['text']",item['text'])
                    
                    output_ls.append(item['text'])
            print("output_ls aadhar",output_ls)
            aadhar_num = "".join(item for item in output_ls)

            print("entered aadhar\n\n",flattened_output_list,"\n\n")
                # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"attachment_num":aadhar_num,
                      "attachment_status":200}
            
        if self.name == "pan":
            pan_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'PAN_VALUE')
            print("entered pan\n\n",flattened_output_list,"\n\n")
            # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"attachment_num":pan_num,
                      "attachment_status":200}
        if self.name == "gst":
            gstin_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'GSTIN')
            print("entered gst\n\n",flattened_output_list,"\n\n")
            # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"attachment_num":gstin_num,
                      "attachment_status":200}
        
            
        # if 
        # an_tokens = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
        #PAN_VALUE
        #AN
        #IFSC

        # print(f"Concatenated AN tokens: {an_tokens}")

        # print("this is flattened output",flattened_output_list)
        for i, flattened_output in enumerate(flattened_output_list):
            annotate_image(data['image_path'][i], flattened_output)
            
        return result
            


_service = ModelHandler()


def handle(data, context, processor, name):
    # Initialization runs on every call (the `initialized` guard is unused),
    # so each call may supply its own model, processor and document type.
    _service.initialize(context, processor, name)
    return _service.handle(data, context)
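

# Minimal usage sketch (illustrative: the checkpoint path, payload values and
# OCR tokens below are assumptions, not part of this module):
#
#   from .utils import load_model, load_processor
#   from .model_base_path import LAYOUTLMV3_BASE_PATH
#
#   model = load_model("path/to/finetuned-layoutlmv3-cheque")
#   processor = load_processor(LAYOUTLMV3_BASE_PATH)
#   payload = {
#       "image_path": ["cheque_001.png"],
#       "words": [["12", "3456", "IFSC0001"]],
#       "bboxes": [[[10, 10, 60, 30], [65, 10, 130, 30], [10, 40, 150, 60]]],
#   }
#   result = handle(payload, model, processor, "cheque")
#   # -> {"attachment_num": "<AN text>", "attachment_ifsc": "<IFSC text>",
#   #     "attachment_status": 200}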