"""Gradio demo: class-conditional DDIM sampling of 32x32 single-channel images.

The script loads a pickled denoising network and a label-embedding module from
disk, builds a small Gradio UI with a class selector, and on each button click
runs a 20-step DDIM sampling loop conditioned on the selected class.
"""
import gradio as gr
import torch
from torch import nn
import torchvision
from diffusers import (
    UNet2DModel,
    UNet2DConditionModel,
    DDPMScheduler,
    DDPMPipeline,
    DDIMScheduler,
)
from fastprogress.fastprogress import progress_bar

# Class index -> human-readable label shown in the UI.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# Inverse lookup: label -> class index.
l2i = {label: idx for idx, label in labels_map.items()}


def label2idx(l):
    """Return the integer class index for the label name *l*.

    Raises:
        KeyError: if *l* is not one of the values of ``labels_map``.
    """
    return l2i[l]


# NOTE(review): torch.load unpickles whole Python objects here, which executes
# arbitrary code from the file. Only load checkpoints from a trusted source.
# Recent PyTorch releases default to ``weights_only=True``, which rejects
# full-model pickles like these -- confirm against the deployed torch version.
unet = torch.load("unconditional01.pt", map_location=torch.device('cpu')).to("cpu")
Emb = torch.load("unconditional_emb_01.pt", map_location=torch.device('cpu')).to("cpu")
unet.eval()

# DDIM scheduler shared by every request; 20 inference steps.
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(20)


@torch.no_grad()
def diff_sample(model, sz, sched, hidden, **kwargs):
    """Run the reverse diffusion loop and return every intermediate sample.

    Args:
        model: Callable invoked as ``model(x_t, t, hidden)``; its result must
            expose a ``.sample`` attribute holding the noise prediction.
        sz: Shape of the initial Gaussian noise tensor.
        sched: Scheduler providing ``timesteps`` and ``step(...)``.
        hidden: Conditioning tensor passed to the model as its third
            positional argument.
        **kwargs: Forwarded unchanged to ``sched.step`` (e.g. ``eta``).

    Returns:
        A list with one float32 CPU tensor per timestep, in sampling order;
        the last element is the final sample.
    """
    x_t = torch.randn(sz)
    preds = []
    for t in progress_bar(sched.timesteps):
        # The decorator already disables autograd, so no inner no_grad block
        # is needed around the forward pass.
        noise = model(x_t, t, hidden).sample
        x_t = sched.step(noise, t, x_t, **kwargs).prev_sample
        preds.append(x_t.float().cpu())
    return preds


@torch.no_grad()
def generate(classChoice):
    """Sample one image for the label *classChoice* and return it as a 2-D array.

    Args:
        classChoice: One of the label strings in ``labels_map``.

    Returns:
        A NumPy float array obtained from the final sample after squeezing
        singleton dimensions, with values clamped to ``[-1, 1]``.

    Raises:
        KeyError: if *classChoice* is not a known label.
    """
    sz = (1, 1, 32, 32)
    print(classChoice)

    # Shape (1, 1): a batch of one, holding a single class index.
    class_idx = torch.tensor([[label2idx(classChoice)]])
    hidden = Emb(class_idx)

    preds = diff_sample(unet, sz, sched, hidden, eta=1.)

    # NOTE(review): the +0.5 shift presumably undoes a -0.5 offset applied to
    # the training images -- confirm against the training script.
    image = (preds[-1][0] + 0.5).squeeze().clamp(-1, 1)
    return image.numpy()


with gr.Blocks() as demo:
    gr.HTML("""

Conditional Diffusion with DDIM

""")
    gr.HTML("""

trained with FashionMNIST

""")
    session_data = gr.State([])
    classChoice = gr.Radio(
        list(labels_map.values()),
        value="T-Shirt",
        label="Select the type of image to generate",
        info="",
    )
    sampling_button = gr.Button("Conditional image generation")
    final_image = gr.Image(height=250, width=200)
    sampling_button.click(
        generate,
        [classChoice],
        [final_image],
    )

demo.queue().launch(share=False, inbrowser=True)