"""Gradio demo: class-conditional FashionMNIST generation with a diffusion UNet and DDIM sampling."""
import gradio as gr
import torch
from diffusers import DDIMScheduler
from fastprogress.fastprogress import progress_bar

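# FashionMNIST class index -> human-readable label.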
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# Inverse mapping: class name -> integer label.
l2i = {l: i for i, l in labels_map.items()}

def label2idx(l):
    return l2i[l]

# Load the trained UNet and the class-embedding table (full pickled modules, CPU-only).
# weights_only=False is required on PyTorch >= 2.6 to unpickle whole modules.
unet = torch.load("unconditional01.pt", map_location="cpu", weights_only=False)
Emb = torch.load("unconditional_emb_01.pt", map_location="cpu", weights_only=False)
unet.eval()

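# DDIM scheduler: 20 inference steps instead of the full training schedule.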
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(20)

@torch.no_grad()
def diff_sample(model, sz, sched, hidden, **kwargs):
    """Run the reverse diffusion loop; return the sample after each step."""
    x_t = torch.randn(sz)  # start from pure Gaussian noise
    preds = []
    for t in progress_bar(sched.timesteps):
        noise = model(x_t, t, hidden).sample                    # predict the noise at step t
        x_t = sched.step(noise, t, x_t, **kwargs).prev_sample   # DDIM update x_t -> x_{t-1}
        preds.append(x_t.float().cpu())
    return preds


@torch.no_grad()
def generate(classChoice):
    sz = (1, 1, 32, 32)  # one single-channel 32x32 image
    # Class-label embedding that conditions the UNet; shape (batch, 1, emb_dim).
    hidden = Emb(torch.tensor([[label2idx(classChoice)]]))
    # eta=1. makes the DDIM update fully stochastic (DDPM-like noise at each step).
    preds = diff_sample(unet, sz, sched, hidden, eta=1.)
    # Map from the model's roughly [-0.5, 0.5] output range into [0, 1] for display.
    return (preds[-1][0] + 0.5).clamp(0., 1.).squeeze().numpy()
    
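# Build the Gradio UI.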
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Conditional Diffusion with DDIM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    classChoice = gr.Radio(list(labels_map.values()), value="T-Shirt", label="Select the type of image to generate")
    sampling_button = gr.Button("Conditional image generation")
    final_image = gr.Image(height=250, width=200, label="Generated image")

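    # Wire the button: selected class in, generated image out.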
    sampling_button.click(
        generate,
        [classChoice],
        [final_image],
    )

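# Queue requests so overlapping clicks wait for the sampling loop to finish.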
demo.queue().launch(share=False, inbrowser=True)