import os
import numpy as np
import torch
from tqdm import tqdm
import json

from models.tta.autoencoder.autoencoder import AutoencoderKL
from models.tta.ldm.inference_utils.vocoder import Generator
from models.tta.ldm.audioldm import AudioLDM
from transformers import T5EncoderModel, AutoTokenizer
from diffusers import PNDMScheduler

from scipy.io.wavfile import write

from utils.util import load_config
import gradio as gr

class AttrDict(dict):
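    """Dict subclass that exposes keys as attributes, matching the config object expected by HiFi-GAN's Generator."""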
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_autoencoderkl(cfg, device):
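    """Build the AutoencoderKL (VAE), load its checkpoint, and freeze it for inference."""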
    autoencoderkl = AutoencoderKL(cfg.model.autoencoderkl)
    autoencoder_path = cfg.model.autoencoder_path
    checkpoint = torch.load(autoencoder_path, map_location="cpu")
    autoencoderkl.load_state_dict(checkpoint["model"])
    autoencoderkl = autoencoderkl.to(device=device)
    autoencoderkl.requires_grad_(requires_grad=False)
    autoencoderkl.eval()
    return autoencoderkl

def build_textencoder(device):
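    """Load the T5 tokenizer and text encoder from local checkpoints and freeze the encoder for inference."""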
    # tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
    # text_encoder = T5EncoderModel.from_pretrained("t5-base")
    tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
    text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
    text_encoder = text_encoder.to(device=device)
    text_encoder.requires_grad_(requires_grad=False)
    text_encoder.eval()
    return tokenizer, text_encoder

def build_vocoder(device):
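    """Build the HiFi-GAN vocoder and load its pretrained generator weights."""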
    config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
    with open(config_file) as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)
    vocoder = Generator(h).to(device)
    checkpoint_dict = torch.load(
        "ckpts/tta/hifigan_checkpoints/g_01250000", map_location=device
    )
    vocoder.load_state_dict(checkpoint_dict["generator"])
    return vocoder

def build_model(cfg):
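    """Instantiate the AudioLDM latent diffusion model from the config."""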
    model = AudioLDM(cfg.model.audioldm)
    return model

def get_text_embedding(text, tokenizer, text_encoder, device):
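    """Encode the prompt and an empty (unconditional) prompt with T5.

    Returns the two embeddings concatenated as [uncond, cond], ready for
    classifier-free guidance.
    """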

    prompt = [text]

    text_input = tokenizer(
        prompt,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    text_embeddings = text_encoder(
        text_input.input_ids.to(device)
    )[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(
        uncond_input.input_ids.to(device)
    )[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings
    
def tta_inference(
        text,
        guidance_scale=4,
        diffusion_steps=100,
):
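    """Generate audio from a text prompt with AudioLDM.

    The prompt is encoded with T5, a random latent is denoised with a PNDM
    scheduler under classifier-free guidance, decoded to a mel spectrogram by
    the VAE, and vocoded to a waveform with HiFi-GAN. The result is written to
    result/<prompt>.wav at 16 kHz and its path is returned. Assumes the
    pretrained checkpoints referenced below are available under ckpts/tta/.
    """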

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the experiment config and build the frozen components: VAE, T5 text
    # encoder, and HiFi-GAN vocoder.
    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tta/audioldm/exp_config.json")

    autoencoderkl = build_autoencoderkl(cfg, device)
    tokenizer, text_encoder = build_textencoder(device)
    vocoder = build_vocoder(device)
    model = build_model(cfg)

    # Load the pretrained AudioLDM diffusion checkpoint.
    checkpoint_path = "ckpts/tta/audioldm_debug_latent_size_4_5_39/checkpoints/step-0570000_loss-0.2521.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    text_embeddings = get_text_embedding(text, tokenizer, text_encoder, device)

    num_steps = diffusion_steps

    # PNDM noise scheduler with a scaled-linear beta schedule; inference runs
    # only `num_steps` denoising steps.
    noise_scheduler = PNDMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
        set_alpha_to_one=False,
        steps_offset=1,
        prediction_type="epsilon",
    )

    noise_scheduler.set_timesteps(num_steps)


    # Sample the initial latent noise. The spatial size is the target mel
    # spectrogram size (80 mel bins x 624 frames) downsampled by the VAE
    # factor 2 ** (len(ch_mult) - 1).
    latents = torch.randn(
        (
            1,
            cfg.model.autoencoderkl.z_channels,
            80 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
            624 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
        )
    ).to(device)

    # Reverse-diffusion sampling loop.
    model.eval()
    for t in tqdm(noise_scheduler.timesteps):
        t = t.to(device)

        # Duplicate the latents so the unconditional and conditional branches
        # run in a single batched forward pass.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = noise_scheduler.scale_model_input(
            latent_model_input, timestep=t
        )

        # predict the noise residual
        with torch.no_grad():
            noise_pred = model(
                latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
            )

        # Apply classifier-free guidance: shift the unconditional prediction
        # toward the text-conditioned one.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # compute the previous noisy sample x_t -> x_t-1
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the denoised latents back to a mel spectrogram with the VAE decoder.
    latents_out = latents

    with torch.no_grad():
        mel_out = autoencoderkl.decode(latents_out)

    melspec = mel_out[0, 0].cpu().detach().numpy()

    # Vocode the mel spectrogram to a waveform with HiFi-GAN and convert it to
    # 16-bit PCM.
    vocoder.eval()
    vocoder.remove_weight_norm()

    with torch.no_grad():
        melspec = np.expand_dims(melspec, 0)
        melspec = torch.FloatTensor(melspec).to(device)

        y = vocoder(melspec)
        audio = y.squeeze()
        audio = audio * 32768.0
        audio = audio.cpu().numpy().astype("int16")

    # Write the waveform as a 16 kHz wav named after the prompt (slashes are
    # replaced so the prompt stays a valid file name).
    os.makedirs("result", exist_ok=True)
    result_path = os.path.join("result", text.replace("/", "_") + ".wav")
    write(result_path, 16000, audio)

    return result_path

demo_inputs = [
    gr.Textbox(
        value="birds singing and a man whistling",
        label="Text prompt describing the audio you want to generate",
        type="text",
    ),
    gr.Slider(
        1,
        10,
        value=4,
        step=1,
        label="Classifier-free guidance scale",
    ),
    gr.Slider(
        50,
        1000,
        value=100,
        step=1,
        label="Diffusion Inference Steps",
        info="More steps generally give better synthesis quality at the cost of slower inference",
    ),
]

demo_outputs = gr.Audio(label="Generated audio")

demo = gr.Interface(
    fn=tta_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Amphion Text to Audio"
)

if __name__ == "__main__":
    demo.launch()