Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torch import nn | |
import torchvision | |
from diffusers import UNet2DModel, UNet2DConditionModel, DDPMScheduler, DDPMPipeline, DDIMScheduler | |
from fastprogress.fastprogress import progress_bar | |
# FashionMNIST class index -> human-readable class name.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))
# Reverse lookup table: class name -> class index.
l2i = dict(zip(labels_map.values(), labels_map.keys()))
def label2idx(l):
    """Return the integer class index for the class name ``l``.

    Looks ``l`` up in the module-level ``l2i`` table; a name that is not in
    the table raises ``KeyError``.
    """
    index = l2i[l]
    return index
# Load the trained networks onto the CPU. The loaded objects are used directly
# as modules (called, .eval()), i.e. these files hold pickled modules rather
# than state_dicts, so weights_only=False is required on torch >= 2.6 (where
# the default flipped to True). Unpickling can execute arbitrary code: only
# load trusted files this way.
unet = torch.load(
    "unconditional01.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
Emb = torch.load(
    "unconditional_emb_01.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Inference only: put both modules in eval mode.
unet.eval()
Emb.eval()
# NOTE(review): beta_end presumably has to match the noise schedule used in
# training — confirm against the training script.
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(20)  # number of denoising steps used at sampling time
def diff_sample(model, sz, sched, hidden, **kwargs):
    """Run the reverse-diffusion loop and return every intermediate sample.

    Starting from Gaussian noise of shape ``sz``, asks ``model`` to predict the
    noise at each timestep in ``sched.timesteps``. It then lets ``sched.step``
    produce the next sample from that prediction.

    Args:
        model: callable ``model(x, t, hidden)`` returning an object with a
            ``.sample`` attribute (as diffusers UNet models do).
        sz: shape of the sample tensor to generate.
        sched: diffusers-style scheduler; ``set_timesteps`` must already have
            been called on it.
        hidden: conditioning tensor passed straight through to ``model``.
        **kwargs: extra keyword arguments forwarded to ``sched.step``
            (e.g. ``eta`` for DDIM).

    Returns:
        List of float CPU tensors, one per timestep; the last is the final
        sample.
    """
    # Scale the starting noise the way the scheduler expects. DDIM/DDPM use
    # init_noise_sigma == 1.0, so this is a no-op there, but it keeps the loop
    # correct for schedulers that start from larger noise.
    x_t = torch.randn(sz) * sched.init_noise_sigma
    preds = []
    # Sampling never needs gradients; one context covers the whole loop.
    with torch.no_grad():
        for t in progress_bar(sched.timesteps):
            # Identity for DDIM; required by schedulers that rescale inputs.
            model_in = sched.scale_model_input(x_t, t)
            noise = model(model_in, t, hidden).sample
            x_t = sched.step(noise, t, x_t, **kwargs).prev_sample
            preds.append(x_t.float().cpu())
    return preds
def generate(classChoice):
    """Sample one image of the requested clothing class.

    Args:
        classChoice: a class name from ``labels_map`` (the Radio selection).

    Returns:
        2-D float numpy array (the single generated image) with values
        clamped to [-1, 1].
    """
    sz = (1, 1, 32, 32)  # (batch, channels, height, width) for the UNet
    print(classChoice)
    class_idx = torch.tensor([[label2idx(classChoice)]])  # shape (1, 1)
    # The embedding lookup needs no gradients; computing it under no_grad
    # avoids building an autograd graph only to detach it afterwards.
    with torch.no_grad():
        hidden = Emb(class_idx).to("cpu")
    preds = diff_sample(unet, sz, sched, hidden, eta=1.0)
    # Take the final sample, drop the batch dim and squeeze to 2-D.
    # NOTE(review): the +0.5 shift suggests training images were centred on
    # zero (roughly [-0.5, 0.5]); confirm against the training transform.
    img = (preds[-1][0] + 0.5).squeeze().clamp(-1, 1)
    return img.detach().numpy()
# Web UI: a class selector, a button that triggers sampling, and an image view.
with gr.Blocks() as demo:
    # Page header.
    gr.HTML("""<h1 align="center">Conditional Diffusion with DDIM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # NOTE(review): session_data is never read in this file; candidate for removal.
    session_data = gr.State([])
    classChoice = gr.Radio(
        choices=list(labels_map.values()),
        value="T-Shirt",
        label="Select the type of image to generate",
        info="",
    )
    sampling_button = gr.Button("Conditional image generation")
    final_image = gr.Image(height=250, width=200)
    # Clicking the button runs generate() on the selected class name and
    # displays the returned array in final_image.
    sampling_button.click(
        fn=generate,
        inputs=[classChoice],
        outputs=[final_image],
    )

demo.queue()
demo.launch(share=False, inbrowser=True)