# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from torch.utils.data import Dataset

from mmocr.datasets.builder import build_loader


@DATASETS.register_module()
class BaseDataset(Dataset):
    """Custom dataset for text detection, text recognition, and their
    downstream tasks.

    1. The text detection annotation format is as follows:
       The `annotations` field is optional for testing
       (this is one line of anno_file, with line-json-str
        converted to dict for visualizing only).

        {
            "file_name": "sample.jpg",
            "height": 1080,
            "width": 960,
            "annotations":
                [
                    {
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": [357.0, 667.0, 804.0, 100.0],
                        "segmentation": [[361, 667, 710, 670,
                                          72, 767, 357, 763]]
                    }
                ]
        }

    2. The two text recognition annotation formats are as follows:
       The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
       augmentation during training.

        format1: sample.jpg hello
        format2: sample.jpg 20 20 100 20 100 40 20 40 hello

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If set True, try...except will
            be turned off in __getitem__.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 pipeline,
                 img_prefix='',
                 test_mode=False):
        super().__init__()
        self.test_mode = test_mode
        self.img_prefix = img_prefix
        self.ann_file = ann_file
        # load annotations
        loader.update(ann_file=ann_file)
        self.data_infos = build_loader(loader)
        # processing pipeline
        self.pipeline = Compose(pipeline)
        # set group flag and class, no meaning
        # for text detect and recognize
        self._set_group_flag()
        self.CLASSES = 0

    def __len__(self):
        return len(self.data_infos)

    def _set_group_flag(self):
        """Set flag."""
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.data_infos[index]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, img_info):
        """Get testing data from pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        return self.prepare_train_img(img_info)

    def _log_error_index(self, index):
        """Logging data info of bad index."""
        try:
            data_info = self.data_infos[index]
            img_prefix = self.img_prefix
            print_log(f'Warning: skip broken file {data_info} '
                      f'with img_prefix {img_prefix}')
        except Exception as e:
            print_log(f'load index {index} with error {e}')

    def _get_next_index(self, index):
        """Get next index from dataset."""
        self._log_error_index(index)
        index = (index + 1) % len(self)
        return index

    def __getitem__(self, index):
        """Get training/test data from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training/test data.
        """
        if self.test_mode:
            return self.prepare_test_img(index)

        while True:
            try:
                data = self.prepare_train_img(index)
                if data is None:
                    raise Exception('prepared train data empty')
                break
            except Exception as e:
                print_log(f'prepare index {index} with error {e}')
                index = self._get_next_index(index)
        return data

    def format_results(self, results, **kwargs):
        """Placeholder to format result to dataset-specific output."""
        pass

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
        Returns:
            dict[str: float]
        """
        raise NotImplementedError