# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets


def _create_dummy_dict_file(dict_file):
    chars = list('0123456789')
    with open(dict_file, 'w') as fw:
        for char in chars:
            fw.write(char + '\n')


def test_ocr_segm_targets():
    """Test ``OCRSegTargets`` argument checks, kernels and ``__call__``.

    A dummy character dictionary for the ``SegConvertor`` is written to a
    temporary directory. The directory is managed by a context manager so
    it is removed even when one of the assertions below fails.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy dict file
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)
        # dummy label convertor
        label_convertor = dict(
            type='SegConvertor',
            dict_file=dict_file,
            with_unknown=True,
            lower=True)
        # test init: malformed constructor arguments must be rejected
        with pytest.raises(AssertionError):
            OCRSegTargets(None, 0.5, 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, '1by2', 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, 0.5, 2)

        ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
        # test generate kernels
        img_size = (8, 8)
        pad_size = (8, 10)
        char_boxes = [[2, 2, 6, 6]]
        char_idxs = [2]

        # malformed img_size / char_boxes / char_idxs must be rejected
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(
                8, pad_size, char_boxes, char_idxs, 0.5, True)
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(
                img_size, pad_size, [2, 2, 6, 6], char_idxs, 0.5, True)
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(
                img_size, pad_size, char_boxes, 2, 0.5, True)

        # binary=True: the kernel inside the [2, 2, 6, 6] box (rows/cols
        # 3-5) is 1; columns 8-9 lie beyond the 8-pixel image width but
        # inside the 10-pixel padded width and are expected to be 255.
        attn_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
        expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(attn_tgt,
                           np.array(expect_attn_tgt, dtype=np.int32))

        # binary=False: the same kernel region is filled with the char
        # index (2) instead of 1.
        segm_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
        expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(segm_tgt,
                           np.array(expect_segm_tgt, dtype=np.int32))

        # test __call__: the char rect is given at img_shape (4x4); it is
        # expected to yield the same masks as the 8x8 case above
        # (presumably via rescaling to resize_shape).
        results = {}
        results['img_shape'] = (4, 4, 3)
        results['resize_shape'] = (8, 8, 3)
        results['pad_shape'] = (8, 10)
        results['ann_info'] = {}
        results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
        results['ann_info']['chars'] = ['1']

        results = ocr_seg_tgt(results)
        assert results['mask_fields'] == ['gt_kernels']
        assert np.allclose(results['gt_kernels'].masks[0],
                           np.array(expect_attn_tgt, dtype=np.int32))
        assert np.allclose(results['gt_kernels'].masks[1],
                           np.array(expect_segm_tgt, dtype=np.int32))