import numpy as np
import torch
from typing import Optional, List, Tuple, NamedTuple, Union
from models import PipelineWrapper
import torchaudio
from audioldm.utils import get_duration

# Optional cap (in seconds) on the duration of loaded audio; None disables the cap.
MAX_DURATION = None


class PromptEmbeddings(NamedTuple):
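    """Bundles the three outputs of PipelineWrapper.encode_text for one prompt batch."""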
    embedding_hidden_states: torch.Tensor
    embedding_class_labels: torch.Tensor
    boolean_prompt_mask: torch.Tensor


def load_audio(audio_path: Union[str, torch.Tensor], fn_STFT, left: int = 0, right: int = 0,
               device: Optional[torch.device] = None,
               return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None
               ) -> Union[Tuple[torch.Tensor, int, float], Tuple[torch.Tensor, int, float, torch.Tensor]]:
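    """Load audio either as a mel spectrogram (stft=True, AudioLDM/TANGO-style) or as a raw
    waveform (stft=False).

    Returns (audio, sample_rate, duration), plus the raw waveform as a fourth element when
    return_wav=True in the stft branch. When a precomputed spectrogram tensor is passed in
    instead of a path, duration (and the waveform) are returned as None. `left`/`right` trim
    frames from the ends of the spectrogram.
    """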
    if stft:  # AudioLDM/TANGO-style loading to a mel spectrogram
        if isinstance(audio_path, str):
            import audioldm
            import audioldm.audio

            duration = get_duration(audio_path)
            if MAX_DURATION is not None:
                duration = min(duration, MAX_DURATION)

            # AudioLDM uses 102.4 mel frames per second (e.g. 1024 frames for 10 s of audio).
            mel, _, wav = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
            mel = mel.unsqueeze(0)
        else:
            # A precomputed mel spectrogram was passed in directly; its duration and the
            # corresponding waveform are unknown here.
            mel = audio_path
            duration, wav = None, None

        c, h, w = mel.shape
        # Clamp the trim amounts so at least one frame remains.
        left = min(left, w - 1)
        right = min(right, w - left - 1)
        mel = mel[:, :, left:w - right]
        mel = mel.unsqueeze(0).to(device)

        if return_wav:
            return mel, 16000, duration, wav  # AudioLDM waveforms are 16 kHz

        return mel, model_sr, duration
    else:
        waveform, sr = torchaudio.load(audio_path)
        if sr != model_sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)

        def normalize_wav(waveform):
            # Remove DC offset, then peak-normalize to roughly [-0.5, 0.5].
            waveform = waveform - torch.mean(waveform)
            waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
            return waveform * 0.5

        waveform = normalize_wav(waveform)
        waveform = waveform.float()
        if MAX_DURATION is not None:
            # Cut the waveform to at most MAX_DURATION seconds.
            duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
            waveform = waveform[:, :int(duration * model_sr)]

        duration = waveform.shape[-1] / model_sr
        return waveform, model_sr, duration


def get_height_of_spectrogram(length: Optional[int], ldm_stable: PipelineWrapper) -> int:
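    """Map an audio length in seconds to a latent-spectrogram height the model accepts,
    rounding up to a multiple of the VAE scale factor.

    Worked example with hypothetical AudioLDM-style values: upsample_rates=[8, 8, 2, 2] and
    a 16 kHz vocoder give vocoder_upsample_factor = 256 / 16000 = 0.016 s per frame, so 10 s
    of audio yields height = 625, rounded up to 628 if vae_scale_factor is 4.
    """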
    vocoder_upsample_factor = np.prod(ldm_stable.model.vocoder.config.upsample_rates) / \
        ldm_stable.model.vocoder.config.sampling_rate

    if length is None:
        length = ldm_stable.model.unet.config.sample_size * ldm_stable.model.vae_scale_factor * \
            vocoder_upsample_factor

    height = int(length / vocoder_upsample_factor)

    # original_waveform_length = int(length * ldm_stable.model.vocoder.config.sampling_rate)
    if height % ldm_stable.model.vae_scale_factor != 0:
        height = int(np.ceil(height / ldm_stable.model.vae_scale_factor)) * ldm_stable.model.vae_scale_factor
        print(
            f"Audio length in seconds {length} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {length} after the "
            f"denoising process."
        )

    return height


def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper
                        ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]:
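    """Encode the target and negative (unconditional) prompts with the pipeline's text
    encoder and bundle each result into a PromptEmbeddings tuple.

    Also returns the raw class-label embeddings of the target prompt as the first element.
    """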
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = \
        ldm_stable.encode_text(target_prompt)
    uncond_embedding_hidden_states, uncond_embedding_class_labels, uncond_boolean_prompt_mask = \
        ldm_stable.encode_text(target_neg_prompt)

    text_emb = PromptEmbeddings(embedding_hidden_states=text_embeddings_hidden_states,
                                boolean_prompt_mask=text_embeddings_boolean_prompt_mask,
                                embedding_class_labels=text_embeddings_class_labels)
    uncond_emb = PromptEmbeddings(embedding_hidden_states=uncond_embedding_hidden_states,
                                  boolean_prompt_mask=uncond_boolean_prompt_mask,
                                  embedding_class_labels=uncond_embedding_class_labels)

    return text_embeddings_class_labels, text_emb, uncond_emb
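

# Minimal usage sketch (hypothetical file path and prompts; assumes a PipelineWrapper
# instance is constructed elsewhere in this repository):
#
#   ldm_stable = ...  # PipelineWrapper instance
#   waveform, sr, duration = load_audio("input.wav", fn_STFT=None, stft=False, model_sr=16000)
#   height = get_height_of_spectrogram(duration, ldm_stable)
#   _, text_emb, uncond_emb = get_text_embeddings(["a dog barking"], [""], ldm_stable)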