# Copyright (c) 2022, National Diet Library, Japan
#
# This software is released under the CC BY 4.0.
# https://creativecommons.org/licenses/by/4.0/


import copy
import cv2
import os


class BaseInferenceProcess:
    """
    各推論処理を実行するプロセスクラスを作るためのメタクラス。

    Attributes
    ----------
    proc_name : str
        推論処理を実行するインスタンスが持つプロセス名。
        [実行される順序を表す数字+クラスごとの処理名]で構成されます。
    cfg : dict
        本推論実行における設定情報です。
    """
    def __init__(self, cfg, pid, proc_type='_base_prep'):
        """
        Parameters
        ----------
        cfg : dict
            本実行処理における設定情報です。
        pid : int
            実行される順序を表す数値。
        proc_type : str
            クラスごとに定義されている処理名。
        """
        self.proc_name = str(pid) + proc_type

        if not self._is_valid_cfg(cfg):
            raise ValueError('Configuration validation error.')
        else:
            self.cfg = cfg

        self.process_dump_dir = None

        return True

    def do(self, data_idx, input_data):
        """
        推論処理を実行する際にOcrInferencerクラスから呼び出される推論実行関数。
        入力データのバリデーションや推論処理、推論結果の保存などが含まれます。
        本処理は基本的に継承先では変更されないことを想定しています。

        Parameters
        ----------
        data_idx : int
            入力データのインデックス。
            画像ファイル1つごとに入力データのリストが構成されます。
        input_data : dict
            推論処理を実行すつ対象の入力データ。

        Returns
        -------
        result : dict
            推論処理の結果を保持する辞書型データ。
            基本的にinput_dataと同じ構造です。
        """
        # input data valudation check
        if not self._is_valid_input(input_data):
            raise ValueError('Input data validation error.')

        # run main inference process
        result = self._run_process(input_data)
        if result is None:
            raise ValueError('Inference output error in {0}.'.format(self.proc_name))

        # dump inference result
        if self.cfg['dump']:
            self._dump_result(input_data, result, data_idx)

        return result

    def _run_process(self, input_data):
        """
        推論処理の本体部分。
        処理内容は継承先のクラスで実装されることを想定しています。

        Parameters
        ----------
        input_data : dict
            推論処理を実行する対象の入力データ。

        Returns
        -------
        result : dict
            推論処理の結果を保持する辞書型データ。
            基本的にinput_dataと同じ構造です。
        """
        print('### Base Inference Process ###')
        result = copy.deepcopy(input_data)
        return result

    def _is_valid_cfg(self, cfg):
        """
        推論処理全体の設定情報ではなく、クラス単位の設定情報に対するバリデーション。
        バリデーションの内容は継承先のクラスで実装されることを想定しています。

        Parameters
        ----------
        cfg : dict
            本推論実行における設定情報です。

        Returns
        -------
        [変数なし] : bool
            設定情報が正しければTrue, そうでなければFalseを返します。
        """
        if cfg is None:
            print('Given configuration data is None.')
            return False
        return True

    def _is_valid_input(self, input_data):
        """
        本クラスの推論処理における入力データのバリデーション。
        バリデーションの内容は継承先のクラスで実装されることを想定しています。

        Parameters
        ----------
        input_data : dict
            推論処理を実行する対象の入力データ。

        Returns
        -------
        [変数なし] : bool
             入力データが正しければTrue, そうでなければFalseを返します。
        """
        return True

    def _dump_result(self, input_data, result, data_idx):
        """
        本クラスの推論処理結果をファイルに保存します。
        dumpフラグが有効の場合にのみ実行されます。

        Parameters
        ----------
        input_data : dict
            推論処理に利用した入力データ。
        result : list
            推論処理の結果を保持するリスト型データ。
            各要素は基本的にinput_dataと同じ構造の辞書型データです。
        data_idx : int
            入力データのインデックス。
            画像ファイル1つごとに入力データのリストが構成されます。
        """

        self.process_dump_dir = os.path.join(os.path.join(input_data['output_dir'], 'dump'), self.proc_name)

        for i, single_result in enumerate(result):
            if 'img' in single_result.keys() and single_result['img'] is not None:
                dump_img_name = os.path.basename(input_data['img_path']).split('.')[0] + '_' + str(data_idx) + '_' + str(i) + '.jpg'
                self._dump_img_result(single_result, input_data['output_dir'], dump_img_name)
            if 'xml' in single_result.keys() and single_result['xml'] is not None:
                dump_xml_name = os.path.basename(input_data['img_path']).split('.')[0] + '_' + str(data_idx) + '_' + str(i) + '.xml'
                self._dump_xml_result(single_result, input_data['output_dir'], dump_xml_name)
            if 'txt' in single_result.keys() and single_result['txt'] is not None:
                dump_txt_name = os.path.basename(input_data['img_path']).split('.')[0] + '_' + str(data_idx) + '_' + str(i) + '.txt'
                self._dump_txt_result(single_result, input_data['output_dir'], dump_txt_name)
        return

    def _dump_img_result(self, single_result, output_dir, img_name):
        """
        本クラスの推論処理結果(画像)をファイルに保存します。
        dumpフラグが有効の場合にのみ実行されます。

        Parameters
        ----------
        single_result : dict
            推論処理の結果を保持する辞書型データ。
        output_dir : str
            推論結果が保存されるディレクトリのパス。
        img_name : str
            入力データの画像ファイル名。
            dumpされる画像ファイルのファイル名は入力のファイル名と同名(複数ある場合は連番を付与)となります。
        """
        pred_img_dir = os.path.join(self.process_dump_dir, 'pred_img')
        os.makedirs(pred_img_dir, exist_ok=True)
        image_file_path = os.path.join(pred_img_dir, img_name)
        dump_image = self._create_result_image(single_result)
        try:
            cv2.imwrite(image_file_path, dump_image)
        except OSError as err:
            print("Dump image save error: {0}".format(err))
            raise OSError

        return

    def _dump_xml_result(self, single_result, output_dir, img_name):
        """
        本クラスの推論処理結果(XML)をファイルに保存します。
        dumpフラグが有効の場合にのみ実行されます。

        Parameters
        ----------
        single_result : dict
            推論処理の結果を保持する辞書型データ。
        output_dir : str
            推論結果が保存されるディレクトリのパス。
        img_name : str
            入力データの画像ファイル名。
            dumpされるXMLファイルのファイル名は入力のファイル名とほぼ同名(拡張子の変更、サフィックスや連番の追加のみ)となります。
        """
        xml_dir = os.path.join(self.process_dump_dir, 'xml')
        os.makedirs(xml_dir, exist_ok=True)
        trum, _ = os.path.splitext(img_name)
        xml_path = os.path.join(xml_dir, trum + '.xml')
        try:
            single_result['xml'].write(xml_path, encoding='utf-8', xml_declaration=True)
        except OSError as err:
            print("Dump xml save error: {0}".format(err))
            raise OSError

        return

    def _dump_txt_result(self, single_result, output_dir, img_name):
        """
        本クラスの推論処理結果(テキスト)をファイルに保存します。
        dumpフラグが有効の場合にのみ実行されます。

        Parameters
        ----------
        single_result : dict
            推論処理の結果を保持する辞書型データ。
        output_dir : str
            推論結果が保存されるディレクトリのパス。
        img_name : str
            入力データの画像ファイル名。
            dumpされるテキストファイルのファイル名は入力のファイル名とほぼ同名(拡張子の変更、サフィックスや連番の追加のみ)となります。
        """
        txt_dir = os.path.join(self.process_dump_dir, 'txt')
        os.makedirs(txt_dir, exist_ok=True)

        trum, _ = os.path.splitext(img_name)
        txt_path = os.path.join(txt_dir, trum + '_main.txt')
        try:
            with open(txt_path, 'w') as f:
                f.write(single_result['txt'])
        except OSError as err:
            print("Dump text save error: {0}".format(err))
            raise OSError

        return

    def _create_result_image(self, single_result):
        """
        推論結果を入力の画像に重畳した画像データを生成します。

        Parameters
        ----------
        single_result : dict
            推論処理の結果を保持する辞書型データ。
        """
        dump_img = None
        if 'dump_img' in single_result.keys():
            dump_img = copy.deepcopy(single_result['dump_img'])
        else:
            dump_img = copy.deepcopy(single_result['img'])
        if 'xml' in single_result.keys() and single_result['xml'] is not None:
            # draw single inferenceresult on input image
            # this should be implemeted in each child class
            cv2.putText(dump_img, 'dump' + self.proc_name, (0, 50),
                        cv2.FONT_HERSHEY_PLAIN, 4, (255, 0, 0), 5, cv2.LINE_AA)
            pass
        else:
            cv2.putText(dump_img, 'dump' + self.proc_name, (0, 50),
                        cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 0), 5, cv2.LINE_AA)
        return dump_img