import os
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

# Symbol inventory for the VITS phoneme add-ons, keyed by add-on name. The
# derived "symbols" list and "_symbol_to_id" lookup are filled in right below.
CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,  # fixed length of the returned phoneme-id sequence
        "_pad": "_",  # first symbol, so it gets id 0
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

# Flat symbol list: pad, then every character of each group, in this order.
CACHE["get_vits_phoneme_ids"]["symbols"] = [
    CACHE["get_vits_phoneme_ids"]["_pad"],
    *CACHE["get_vits_phoneme_ids"]["_punctuation"],
    *CACHE["get_vits_phoneme_ids"]["_letters"],
    *CACHE["get_vits_phoneme_ids"]["_letters_ipa"],
    *CACHE["get_vits_phoneme_ids"]["_special"],
]
# Symbol -> position in the list above. A character listed more than once
# (e.g. the apostrophe in "_letters_ipa") maps to its LAST position.
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = dict(
    zip(
        CACHE["get_vits_phoneme_ids"]["symbols"],
        range(len(CACHE["get_vits_phoneme_ids"]["symbols"])),
    )
)


def get_vits_phoneme_ids(config, dl_output, metadata):
    """Convert ``metadata["phonemes"]`` to a fixed-length VITS phoneme-id tensor.

    Each character of the phoneme string is looked up in
    ``CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]``. A 0 is inserted before
    every id and one more 0 is appended, which gives ``2 * len(phonemes) + 1``
    ids. The result is then truncated or right-padded with 0 to exactly
    ``PAD_LENGTH`` entries.

    Args:
        config: Unused; accepted for the common add-on signature.
        dl_output: Unused; accepted for the common add-on signature.
        metadata: Must contain a ``"phonemes"`` string.

    Returns:
        ``{"phoneme_idx": torch.LongTensor}`` of length ``PAD_LENGTH``.

    Raises:
        AssertionError: if ``"phonemes"`` is missing from ``metadata``.
        KeyError: if a character is not in the VITS symbol table.
    """
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"

    sequence = [_symbol_to_id[symbol] for symbol in metadata["phonemes"]]

    # Interleave blanks: [0, s0, 0, s1, ..., 0, s_{n-1}, 0] (length 2n + 1).
    inserted_zero_sequence = [0] * (len(sequence) * 2 + 1)
    inserted_zero_sequence[1::2] = sequence

    # Truncate before padding so the output always has exactly pad_length ids.
    # Without this, long inputs yield a longer tensor than every other sample
    # (the no-padding variant already truncates the same way).
    inserted_zero_sequence = inserted_zero_sequence[:pad_length]
    padded = inserted_zero_sequence + [pad_token_id] * (
        pad_length - len(inserted_zero_sequence)
    )

    return {"phoneme_idx": torch.LongTensor(padded)}


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    """Map ``metadata["phonemes"]`` to VITS symbol ids without blank interleaving.

    Unlike ``get_vits_phoneme_ids``, no 0 is inserted between symbols. A
    trailing "⚠" symbol is appended to the phoneme string (it appears to act as
    an end marker). Characters missing from the symbol table are reported with
    ``print`` and encoded as the "_" symbol. Despite the name, the result is
    still truncated / right-padded with 0 to ``PAD_LENGTH``.

    Args:
        config: Unused; accepted for the common add-on signature.
        dl_output: Unused; accepted for the common add-on signature.
        metadata: Must contain a ``"phonemes"`` string.

    Returns:
        ``{"phoneme_idx": torch.LongTensor}`` of length ``PAD_LENGTH``.
    """
    table = CACHE["get_vits_phoneme_ids"]
    max_len = table["PAD_LENGTH"]
    lookup = table["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
    clean_text = metadata["phonemes"] + "⚠"

    ids = []
    for char in clean_text:
        if char in lookup:
            ids.append(lookup[char])
        else:
            print("%s is not in the vocabulary. %s" % (char, clean_text))
            ids.append(lookup["_"])

    # NOTE(review): when the input is too long the appended "⚠" is cut off here.
    ids = ids[:max_len]
    ids.extend([0] * (max_len - len(ids)))

    return {"phoneme_idx": torch.LongTensor(ids)}


def calculate_relative_bandwidth(config, dl_output, metadata):
    """Locate the 5th / 95th percentile energy bins of ``dl_output["stft"]``.

    Energy is summed over dim 0 and accumulated along the frequency axis (the
    last dim). Each bound is the bin whose cumulative energy is closest to
    5% / 95% of the total. It is expressed in per-mille of the bin count (0..1000).

    NOTE(review): summing over dim 0 assumes a 2-D (time, freq) STFT - TODO confirm.

    Returns:
        ``{"freq_energy_percentile": torch.LongTensor([low, high])}``.
    """
    assert "stft" in dl_output.keys()

    spectrum = dl_output["stft"]
    n_bins = spectrum.size(-1)  # frequency is the last dimension

    cumulative = spectrum.sum(dim=0).cumsum(dim=0)
    total = cumulative[-1]

    def _closest_bin(target):
        # Bin whose cumulative energy is nearest to `target`.
        return torch.argmin(torch.abs(target - cumulative))

    low_bin = _closest_bin(total * 0.05)
    high_bin = _closest_bin(total * 0.95)

    low_permille = int((low_bin / n_bins) * 1000)
    high_permille = int((high_bin / n_bins) * 1000)

    return {"freq_energy_percentile": torch.LongTensor([low_permille, high_permille])}


def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    """Encode the 5th-95th percentile energy band of the mel spectrogram as a mask.

    ``dl_output["log_mel_spec"]`` is clipped at 10 and exponentiated back to a
    linear scale. It is then summed over dim 0 and accumulated along the mel
    axis. The bins whose cumulative energy is closest to 5% and 95% of the total
    are the band edges. They are rescaled from the mel-bin range to
    ``latent_f_size`` columns.

    Args:
        config: Must provide ``config["model"]["params"]["latent_t_size"]`` and
            ``config["model"]["params"]["latent_f_size"]``.
        dl_output: Must contain ``"log_mel_spec"``; assumed 2-D
            (time, mel bins) - TODO confirm.
        metadata: Unused.

    Returns:
        dict with
        ``"mel_spec_bandwidth_cond_extra_channel"``: float tensor of shape
        ``(latent_t_size, latent_f_size)`` holding 1.0 in columns
        ``[lower, higher)`` and 0.0 elsewhere, and
        ``"freq_energy_percentile"``: ``torch.LongTensor([lower, higher])`` in
        latent-column units.

    Raises:
        AssertionError: if ``"log_mel_spec"`` is missing from ``dl_output``.
    """
    # Validate the key this add-on actually reads; it does not use "stft".
    assert "log_mel_spec" in dl_output.keys(), (
        "calculate_mel_spec_relative_bandwidth_as_extra_channel requires log_mel_spec in dl_output"
    )
    # Clip before exp(), presumably to keep the linear values bounded.
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the mel spectrogram is the mel-frequency dimension.
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    # Rescale mel-bin indices to latent frequency columns.
    lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions)))
    higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions)))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }


def waveform_rs_48k(config, dl_output, metadata):
    """Expose ``dl_output["waveform"]`` resampled to 48 kHz as ``"waveform_48k"``.

    When ``dl_output["sampling_rate"]`` is already 48000, the input waveform
    object is returned unchanged. Otherwise it is passed through
    ``torchaudio.functional.resample``.
    """
    target_rate = 48000
    waveform = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]

    if source_rate == target_rate:
        return {"waveform_48k": waveform}

    resampled = torchaudio.functional.resample(
        waveform, orig_freq=source_rate, new_freq=target_rate
    )
    return {"waveform_48k": resampled}


def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce VITS phoneme ids for every sample, plus an empty text for TTS data.

    Metadata carrying a FastSpeech-style ``"phoneme"`` key is rejected. Samples
    with ``"phonemes"`` are encoded with ``get_vits_phoneme_ids_no_padding`` and
    get ``"text"`` set to "". All other samples get the encoding of an empty
    phoneme string, and no ``"text"`` key is added.
    """
    assert (
        "phoneme" not in metadata.keys()
    ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"

    if "phonemes" not in metadata.keys():
        # Non-speech sample: encode an empty phoneme sequence.
        return get_vits_phoneme_ids_no_padding(config, dl_output, {"phonemes": ""})

    new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
    new_item["text"] = ""  # We assume TTS data does not have text description
    return new_item


def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce FastSpeech-style phoneme ids for every sample.

    When ``metadata`` has a ``"phoneme"`` list, it is encoded with
    ``extract_fs2_phoneme_g2p_en_feature`` and ``"text"`` is set to "".
    Otherwise an empty phoneme list is encoded and no ``"text"`` key is added.
    """
    has_phonemes = "phoneme" in metadata.keys()
    source = metadata if has_phonemes else {"phoneme": []}

    new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, source)
    if has_phonemes:
        new_item["text"] = ""
    return new_item


def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Encode ``metadata["phoneme"]`` with the 71-entry stress-marked phoneme table.

    Unknown phonemes are silently dropped. The id list is truncated to 135
    entries and right-padded with id 71 (one past the last table id). A warning
    is printed when the input is more than 5x the fixed length.

    Returns:
        ``{"phoneme_idx": torch.LongTensor}`` of length 135.
    """
    max_len = 135

    # Ids are the positions in this list (0..70).
    phoneme_inventory = [
        "K", "IH2", "NG", "OW2", "AH2", "F", "AE0", "IY0", "SH", "G",  # 0-9
        "W", "UW1", "AO2", "AW2", "UW0", "EY2", "UW2", "AE2", "IH0", "P",  # 10-19
        "D", "ER1", "AA1", "EH0", "UH1", "N", "V", "AY1", "EY1", "UH2",  # 20-29
        "EH1", "L", "AA2", "R", "OY1", "Y", "ER2", "S", "AE1", "AH1",  # 30-39
        "JH", "ER0", "EH2", "IY2", "OY2", "AW1", "IH1", "IY1", "OW0", "AO0",  # 40-49
        "AY0", "EY0", "AY2", "UH0", "M", "TH", "T", "OY0", "AW0", "DH",  # 50-59
        "Z", "spn", "AH0", "sp", "AO1", "OW1", "ZH", "B", "AA0", "CH",  # 60-69
        "HH",  # 70
    ]
    lookup = {name: idx for idx, name in enumerate(phoneme_inventory)}
    pad_id = len(lookup)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    ids = [lookup[p] for p in metadata["phoneme"] if p in lookup]

    if len(ids) > 5 * max_len:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    ids = ids[:max_len]
    ids += [pad_id] * (max_len - len(ids))

    return {"phoneme_idx": torch.LongTensor(ids)}


def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Encode ``metadata["phoneme"]`` with the 40-entry stress-free phoneme table.

    Id 0 is the " " symbol, followed by 39 phonemes. Unknown phonemes are
    silently dropped. The id list is truncated to 250 entries and right-padded
    with id 40 (one past the last table id). A warning is printed when the
    input is more than 5x the fixed length.

    Returns:
        ``{"phoneme_idx": torch.LongTensor}`` of length 250.
    """
    max_len = 250

    # Ids are the positions in this list (0..39).
    phoneme_inventory = [
        " ",  # 0
        "AA", "AE", "AH", "AO", "AW", "AY", "B", "CH", "D",  # 1-9
        "DH", "EH", "ER", "EY", "F", "G", "HH", "IH", "IY", "JH",  # 10-19
        "K", "L", "M", "N", "NG", "OW", "OY", "P", "R", "S",  # 20-29
        "SH", "T", "TH", "UH", "UW", "V", "W", "Y", "Z", "ZH",  # 30-39
    ]
    lookup = {name: idx for idx, name in enumerate(phoneme_inventory)}
    pad_id = len(lookup)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    ids = [lookup[p] for p in metadata["phoneme"] if p in lookup]

    if len(ids) > 5 * max_len:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    ids = ids[:max_len]
    ids += [pad_id] * (max_len - len(ids))

    return {"phoneme_idx": torch.LongTensor(ids)}


def extract_kaldi_fbank_feature(config, dl_output, metadata):
    """Compute a normalized 128-bin Kaldi-compatible fbank from a 16 kHz waveform.

    The waveform is resampled to 16 kHz if needed and its mean is subtracted.
    ``torchaudio.compliance.kaldi.fbank`` is then run (Hanning window, 10 ms
    shift, no dither, no energy term). The frame axis is zero-padded or cut to
    ``dl_output["log_mel_spec"].size(0)`` frames, and values are normalized as
    ``(fbank - mean) / (2 * std)``.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` with ``log_mel_spec.size(0)`` rows and
        128 columns.
    """
    # Fixed normalization statistics; presumably dataset-level fbank stats - TODO confirm.
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]
    reference_mel = dl_output["log_mel_spec"]

    audio = (
        waveform
        if source_rate == 16000
        else torchaudio.functional.resample(
            waveform, orig_freq=source_rate, new_freq=16000
        )
    )
    audio = audio - audio.mean()  # remove DC offset

    fbank = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Align the frame count with the reference mel spectrogram.
    target_frames = reference_mel.size(0)
    missing = target_frames - fbank.shape[0]
    if missing > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing))(fbank)
    elif missing < 0:
        fbank = fbank[:target_frames, :]

    return {"ta_kaldi_fbank": (fbank - norm_mean) / (norm_std * 2)}


def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    """Compute a normalized 128-bin Kaldi-compatible fbank from a 32 kHz waveform.

    This is the same pipeline as the 16 kHz variant, but it resamples to and
    extracts at 32 kHz. The frame axis is zero-padded or cut to
    ``dl_output["log_mel_spec"].size(0)`` frames, and values are normalized as
    ``(fbank - mean) / (2 * std)`` with the same constants.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` with ``log_mel_spec.size(0)`` rows and
        128 columns.
    """
    # Same fixed statistics as the 16 kHz variant; origin not visible here - TODO confirm.
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]
    reference_mel = dl_output["log_mel_spec"]

    audio = (
        waveform
        if source_rate == 32000
        else torchaudio.functional.resample(
            waveform, orig_freq=source_rate, new_freq=32000
        )
    )
    audio = audio - audio.mean()  # remove DC offset

    fbank = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=32000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Align the frame count with the reference mel spectrogram.
    target_frames = reference_mel.size(0)
    missing = target_frames - fbank.shape[0]
    if missing > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing))(fbank)
    elif missing < 0:
        fbank = fbank[:target_frames, :]

    return {"ta_kaldi_fbank": (fbank - norm_mean) / (norm_std * 2)}


# Use the beat and downbeat information as music conditions
def extract_drum_beat(config, dl_output, metadata):
    """Rasterize beat / downbeat annotations into a latent-sized conditioning map.

    ``metadata["beat"]`` and ``metadata["downbeat"]`` are sample positions at
    ``metadata["sample_rate"]``. They are shifted by the segment start
    (``dl_output["random_start_sample_in_original_audio_file"]``). Events
    outside ``[0, sample_rate * duration]`` are dropped. Each kept event is
    mapped to a latent time frame (clamped to the last frame), and the whole
    row is updated: -0.5 per beat, +1.0 per downbeat. Events that share a frame
    accumulate.

    Returns:
        ``{"cond_beat_downbeat": tensor}`` of shape
        ``(latent_t_size, latent_f_size)``.
    """

    def visualization(conditional_signal, mel_spectrogram, filename):
        # Debug helper, not called by default: writes the segment audio to the
        # current directory and saves a plot of the map and the mel spectrogram.
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))

        plt.subplot(211)
        plt.imshow(np.array(conditional_signal).T, aspect="auto")
        plt.title("Conditional Signal")

        plt.subplot(212)
        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
        plt.title("Mel Spectrogram")

        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]
    # Segment length in annotation samples, before any resampling by the dataloader.
    segment_len = int(sampling_rate * duration)
    segment_start = int(dl_output["random_start_sample_in_original_audio_file"])

    def _events_in_segment(positions):
        # Segment-relative positions, keeping only those within [0, segment_len].
        shifted = (pos - segment_start for pos in positions)
        return [pos for pos in shifted if 0 <= pos <= segment_len]

    beat = _events_in_segment(metadata["beat"])
    downbeat = _events_in_segment(metadata["downbeat"])

    params = config["model"]["params"]
    latent_shape = (params["latent_t_size"], params["latent_f_size"])
    conditional_signal = torch.zeros(latent_shape)
    last_frame = conditional_signal.size(0) - 1

    def _frame_of(pos):
        # Latent time frame for a relative sample position; an event exactly at
        # the segment end would map to latent_t_size, so clamp it.
        return min(int((pos / segment_len) * latent_shape[0]), last_frame)

    # Per-frame encoding: 0 none, -0.5 beat, 1.0 downbeat, 0.5 downbeat+beat.
    for pos in beat:
        conditional_signal[_frame_of(pos), :] -= 0.5
    for pos in downbeat:
        conditional_signal[_frame_of(pos), :] += 1.0

    # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")

    return {"cond_beat_downbeat": conditional_signal}