from typing import Dict, List, NoReturn, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule

from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32


class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module.

        Args:
            train_sampler: Sampler object
            train_dataset: Dataset object
            num_workers: int
            distributed: bool
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed

    def setup(self, stage: Optional[str] = None) -> NoReturn:
        r"""called on every device."""

        # SegmentSampler is used for selecting segments for training.
        # On multiple devices, each SegmentSampler samples a part of mini-batch
        # data.
        if self.distributed:
            self.train_sampler = DistributedSamplerWrapper(self._train_sampler)

        else:
            self.train_sampler = self._train_sampler

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader."""
        train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return train_loader


class Dataset:
    def __init__(self, augmentor: object, segment_samples: int):
        r"""Used for getting data according to a meta.

        Args:
            augmentor: Augmentor class
            segment_samples: int
        """
        self.augmentor = augmentor
        self.segment_samples = segment_samples

    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta. E.g., an input meta looks like: {
            'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
            'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}.
        }

        Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
        Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation).
        Finally, mixture is created by summing vocals and accompaniment.

        Args:
            meta: dict, e.g., {
                'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
                'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}
            }

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels, segments_num),
                'accompaniment': (channels, segments_num),
                'mixture': (channels, segments_num),
            }
        """
        source_types = meta.keys()
        data_dict = {}

        for source_type in source_types:
            # E.g., ['vocals', 'bass', ...]

            waveforms = []  # Audio segments to be mix-audio augmented.

            for m in meta[source_type]:
                # E.g., {
                #     'hdf5_path': '.../song_A.h5',
                #     'key_in_hdf5': 'vocals',
                #     'begin_sample': '13406400',
                #     'end_sample': 13538700,
                # }

                hdf5_path = m['hdf5_path']
                key_in_hdf5 = m['key_in_hdf5']
                bgn_sample = m['begin_sample']
                end_sample = m['end_sample']

                with h5py.File(hdf5_path, 'r') as hf:

                    if source_type == 'audioset':
                        index_in_hdf5 = m['index_in_hdf5']
                        waveform = int16_to_float32(
                            hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                        )
                        waveform = waveform[None, :]
                    else:
                        waveform = int16_to_float32(
                            hf[key_in_hdf5][:, bgn_sample:end_sample]
                        )

                    if self.augmentor:
                        waveform = self.augmentor(waveform, source_type)

                    waveform = librosa.util.fix_length(
                        waveform, size=self.segment_samples, axis=1
                    )
                    # (channels_num, segments_num)

                waveforms.append(waveform)
            # E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]

            # mix-audio augmentation
            data_dict[source_type] = np.sum(waveforms, axis=0)
            # data_dict[source_type]: (channels_num, audio_samples)

        # data_dict looks like: {
        #     'voclas': (channels_num, audio_samples),
        #     'accompaniment': (channels_num, audio_samples)
        # }

        # Mix segments from different sources.
        mixture = np.sum(
            [data_dict[source_type] for source_type in source_types], axis=0
        )
        data_dict['mixture'] = mixture
        # shape: (channels_num, audio_samples)

        return data_dict


def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)
            },
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)
            },
            ...]

    Returns:
        data_dict: e.g. {
            'vocals': (batch_size, channels_num, segment_samples),
            'accompaniment': (batch_size, channels_num, segment_samples),
            'mixture': (batch_size, channels_num, segment_samples)
            }
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        data_dict[key] = torch.Tensor(
            np.array([data_dict[key] for data_dict in list_data_dict])
        )

    return data_dict