import gradio as gr
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor

data_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=ToTensor(), download=True)
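
# Illustrative sanity check (not required by the app): with ToTensor, each test item is a
# (1, 28, 28) float tensor scaled to [0, 1] plus an integer class label in 0-9.
_img0, _label0 = data_test[0]
assert _img0.shape == (1, 28, 28) and 0 <= _label0 <= 9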

class MyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),   # 28, 28
            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 14, 14
            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (nonlinearity)
            nn.SiLU(),
            # (conv_shortcut)
            #nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 7, 7
            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out)
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),   # 3, 7, 7 code
            #nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=3//2),  # 14, 14
            #nn.ReLU(),
            #nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=3//2),  # 7, 7
            #nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),   # 7, 7
            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 14, 14
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16, 28, 28
            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),   # 1, 28, 28
            nn.Sigmoid()
        )
    def forward(self, xb, yb):
        # yb is accepted but unused by this autoencoder.
        x = self.encoder(xb)
        #print("current:", x.shape)
        x = self.decoder(x)
        #print("current decoder:", x.shape)
        #x = x.flatten(start_dim=1).mean(dim=1, keepdim=True)
        #print(x.shape, xb.shape)
        return x, F.mse_loss(x, xb)
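
# Illustrative sketches, not called by the app. The first checks the shapes implied by the
# layers above: the encoder maps a 1x28x28 image to a 3x7x7 code, and the decoder maps the
# code back to 1x28x28. The second is a hypothetical single training step, assuming the
# checkpoint was trained on the plain MSE reconstruction loss that forward() returns.
def _check_shapes(model: nn.Module) -> None:
    dummy = torch.randn(1, 1, 28, 28)
    with torch.no_grad():
        code = model.encoder(dummy)
        recon, _ = model(dummy, None)
    assert code.shape == (1, 3, 7, 7)
    assert recon.shape == (1, 1, 28, 28)

def _train_step(model: nn.Module, optimizer: torch.optim.Optimizer, xb: torch.Tensor) -> float:
    model.train()
    optimizer.zero_grad()
    _, loss = model(xb, None)
    loss.backward()
    optimizer.step()
    return loss.item()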

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# checkpoint10.pt is a full pickled module, so the MyVAE class above must be defined before
# loading. (On newer PyTorch releases that default to weights_only=True, torch.load may need
# weights_only=False for such a file.)
myVAE2 = torch.load("checkpoint10.pt", map_location=torch.device('cpu')).to("cpu")
myVAE2.eval()

def sample():
    idx = torch.randint(0, len(data_test), (1,)).item()
    img, label = data_test[idx]
    img = img.squeeze()  # 28x28 tensor in [0, 1]
    #print(idx, img.shape)
    with torch.no_grad():  # inference only, no gradients needed
        latent = myVAE2.encoder(img[None, None, ...])  # 1x3x7x7
        recon = myVAE2.decoder(latent)                 # 1x1x28x28
    img_original = torchvision.transforms.functional.to_pil_image(img)
    img_encoded = torchvision.transforms.functional.to_pil_image(latent.squeeze())
    img_decoded = torchvision.transforms.functional.to_pil_image(recon.squeeze())
    return img_original, img_encoded, img_decoded, labels_map[label]
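
# Optional helper (an illustrative sketch, not wired into the UI): the encoder output is an
# unbounded 3x7x7 tensor, so min-max scaling it to [0, 1] before to_pil_image would usually
# give a more readable "encoded image" than converting the raw values directly.
def _latent_to_pil(latent: torch.Tensor):
    latent = latent.squeeze()  # 3x7x7
    latent = (latent - latent.min()) / (latent.max() - latent.min() + 1e-8)
    return torchvision.transforms.functional.to_pil_image(latent)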

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Variational Autoencoder</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])
    sampling_button = gr.Button("Sample random FashionMNIST image")
    with gr.Row():
        with gr.Column(scale=2):
            #gr.Label("Original image")
            gr.HTML("""<h3 align="left">Original image</h3>""")
            sample_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            #gr.Label("Encoded image")
            gr.HTML("""<h3 align="left">Encoded image</h3>""")
            encoded_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Decoded image</h3>""")
            #gr.Label("Decoded image")
            decoded_image = gr.Image(height=250, width=200)
    image_label = gr.Label(label="Image label")
    sampling_button.click(
        sample,
        [],
        [sample_image, encoded_image, decoded_image, image_label],
    )

demo.queue().launch(share=False, inbrowser=True)