import json
import os
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Union
from zipfile import ZipFile


def validate_zip(submission_track: str, submission_zip: Union[Path, str]):
    """
    Validates the submission format and contents

    The track name is checked before the archive is opened, so an unknown track
    is rejected without touching (or extracting) the zip file.

    Args:
        submission_track: the track of the submission; one of 'NOTSOFAR-SC', 'NOTSOFAR-MC',
            'DASR-Constrained-LM', 'DASR-Unconstrained-LM'
        submission_zip: path to the submission zip file
    Raises:
        ValueError: if the submission track is unknown or the submission contents are invalid
        zipfile.BadZipFile: if `submission_zip` is not a readable zip archive (propagated from ZipFile)
    """
    notsofar_tracks = ('NOTSOFAR-SC', 'NOTSOFAR-MC')
    dasr_tracks = ('DASR-Constrained-LM', 'DASR-Unconstrained-LM')

    # Fail fast on an unknown track: there is nothing to validate the archive against.
    if submission_track not in notsofar_tracks + dasr_tracks:
        raise ValueError(f'Invalid submission track: {submission_track}')

    with TemporaryDirectory() as temp_dir:
        # NOTE(review): extraction has no size limit; an oversized archive will fill the temp dir.
        with ZipFile(submission_zip, 'r') as submission_zip_file:
            submission_zip_file.extractall(temp_dir)
        submission_dir = Path(temp_dir)
        if submission_track in notsofar_tracks:
            validate_notsofar_submission(submission_dir=submission_dir)
        else:
            validate_dasr_submission(submission_dir=submission_dir)


def validate_notsofar_submission(submission_dir: Path):
    """
    Check an extracted NOTSOFAR submission directory.

    `tcp_wer_hyp.json` must be present. `tc_orc_wer_ref.json` is optional and is
    only checked when it exists. Each checked file must pass
    `validate_json_file_structure` with the fields listed below.

    Args:
        submission_dir: directory holding the extracted submission files
    Raises:
        ValueError: if the required file is missing or a present file is malformed
    """
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']

    # Required files first: a missing hypothesis file is reported before the
    # optional file is inspected at all.
    for required_name in ['tcp_wer_hyp.json']:
        required_path = submission_dir / required_name
        if not required_path.exists():
            raise ValueError(f'Missing {required_name}')
        validate_json_file_structure(required_path, entry_fields)

    for optional_name in ['tc_orc_wer_ref.json']:
        optional_path = submission_dir / optional_name
        if optional_path.exists():
            validate_json_file_structure(optional_path, entry_fields)


def validate_dasr_submission(submission_dir: Path):
    """
    Check an extracted DASR submission directory.

    The submission must contain a `dev` entry holding `chime6.json`, `dipco.json`,
    `mixer6.json` and `notsofar1.json`. The files are checked in that order, and
    the first missing one is reported.

    Args:
        submission_dir: directory holding the extracted submission files
    Raises:
        ValueError: if `dev` or any expected file is missing, or a file is malformed
    """
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
    dev_dir = submission_dir / 'dev'

    if not dev_dir.exists():
        raise ValueError('Missing `dev` directory, expecting a directory named `dev` with the submission files in it.')

    for dataset_file in ('chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'):
        json_path = dev_dir / dataset_file
        if not json_path.exists():
            raise ValueError(f'Missing {dataset_file}')
        validate_json_file_structure(json_path, entry_fields)


def validate_json_file_structure(file_path: Path, fields: List[str]):
    """
    Validates the structure of a json file

    The file must contain a JSON array whose elements are all JSON objects that
    contain (at least) every name in `fields`. Extra keys are allowed.

    Args:
        file_path: path to the json file
        fields: list of fields that are required in each entry
    Raises:
        ValueError: if the json file is invalid. This covers content that cannot be
            parsed as JSON (json.JSONDecodeError, a ValueError subclass), a top level
            that is not a list, and an entry that is not an object or lacks a
            required field.
    """
    # Binary mode lets json.load detect the encoding itself (UTF-8/16/32, tolerating a
    # UTF-8 BOM) instead of relying on the platform's default text encoding.
    with open(file_path, 'rb') as json_file:
        json_data: Any = json.load(json_file)

    if not isinstance(json_data, list):
        raise ValueError(f'Invalid `{file_path.name}` format, expecting a list of entries')

    for index, entry in enumerate(json_data):
        # Without this check `field in entry` would do substring matching on strings,
        # element matching on lists, and raise TypeError on numbers / null.
        if not isinstance(entry, dict):
            raise ValueError(f'Invalid `{file_path.name}` format, entry {index} is not an object')
        missing = [field for field in fields if field not in entry]
        if missing:
            raise ValueError(f'Invalid `{file_path.name}` format, fields: {fields} are required in each entry '
                             f'(entry {index} is missing {missing})')


####################################################################################################
# Tests
####################################################################################################

class TestValidateZip(unittest.TestCase):
    """
    End-to-end tests for `validate_zip`.

    Each test writes a small submission zip into a per-test temporary directory.
    It then checks that well-formed submissions validate and malformed ones
    raise ValueError.
    """
    DATA_SAMPLES = 10  # number of identical entries written into every JSON file
    DASR_FILES = ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']

    @classmethod
    def setUpClass(cls):
        cls.valid_data = [{'session_id': 'session_id', 'words': 'words', 'speaker': 'speaker',
                           'start_time': 0.0, 'end_time': 1.0} for _ in range(cls.DATA_SAMPLES)]
        # Entries lacking the required 'speaker' and 'end_time' fields.
        cls.invalid_data = [{'session_id': 'session_id', 'words': 'words',
                             'start_time': 0.0} for _ in range(cls.DATA_SAMPLES)]

    def setUp(self):
        self.temp_dir = TemporaryDirectory()
        self.submission_zip = Path(self.temp_dir.name) / 'submission.zip'

    def tearDown(self):
        self.temp_dir.cleanup()

    def create_test_data(self, submission_track: str, data: List[Dict[str, Any]], json_file_names: List[str],
                         parent_zip_dir: Optional[str] = None):
        """
        Write `data` as JSON into each of `json_file_names` inside the test zip.

        Args:
            submission_track: track name, passed through unchanged
            data: entries serialized into every JSON file
            json_file_names: member file names to create in the zip
            parent_zip_dir: optional folder inside the zip to place the files under
        Returns:
            (submission_track, submission_zip), ready to be unpacked into `validate_zip`
        """
        with ZipFile(self.submission_zip, 'w') as submission_zip_file:
            for json_file_name in json_file_names:
                # Zip member names always use '/' as the separator, whatever the host OS.
                arcname = f'{parent_zip_dir}/{json_file_name}' if parent_zip_dir else json_file_name
                submission_zip_file.writestr(arcname, json.dumps(data))
        return submission_track, self.submission_zip

    def test_invalid_track(self):
        with self.assertRaisesRegex(ValueError, 'Invalid submission track'):
            validate_zip(*self.create_test_data('NOT-A-TRACK', self.valid_data, ['tcp_wer_hyp.json']))

    def test_NOTSOFAR_SC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_SC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_SC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_SC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_NOTSOFAR_MC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_MC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_MC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_MC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_DASR_Constrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Constrained-LM', self.valid_data, self.DASR_FILES, 'dev')))

    def test_DASR_Constrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.invalid_data, self.DASR_FILES, 'dev'))

    def test_DASR_Constrained_LM_missing_dev_dir(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILES))

    def test_DASR_Constrained_LM_missing_json_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILES[:-1], 'dev'))

    def test_DASR_Unconstrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES, 'dev')))

    def test_DASR_Unconstrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.invalid_data, self.DASR_FILES, 'dev'))

    def test_DASR_Unconstrained_LM_missing_dev_dir(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES))

    def test_DASR_Unconstrained_LM_missing_json_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILES[:-1], 'dev'))


if __name__ == '__main__':
    # Running this module as a script runs its unittest test cases (TestValidateZip).
    unittest.main()