import gradio as gr
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision.transforms import ToTensor

# FashionMNIST test split, used as the pool of images to sample from.
data_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=ToTensor(), download=True)
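# Each sample is a (1, 28, 28) float tensor in [0, 1] (via ToTensor) paired with an integer label 0-9.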
class MyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 14, 14
            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (nonlinearity)
            nn.SiLU(),
            # (conv_shortcut), omitted:
            # nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 7, 7
            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out)
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 3, 7, 7
        )
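        # Encoder spatial path: 28x28 -> 14x14 -> 7x7; the 3-channel latent can be rendered directly as a small RGB image.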
        self.decoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 14, 14
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16, 28, 28
            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Sigmoid(),
        )
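        # Decoder mirrors the encoder: 7x7 -> 14x14 -> 28x28; Sigmoid keeps outputs in [0, 1] to match the ToTensor inputs.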
    def forward(self, xb, yb):
        # yb (the label batch) is unused; it is kept for interface compatibility with the training loop.
        x = self.encoder(xb)
        x = self.decoder(x)
        # Plain reconstruction loss: despite the VAE name, the encoder is deterministic and there is no KL term.
        return x, F.mse_loss(x, xb)
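# A minimal shape sanity check (illustrative only, not part of the app flow; uses a randomly
# initialized MyVAE, so the outputs are meaningless beyond their shapes):
#   with torch.no_grad():
#       z = MyVAE().encoder(torch.zeros(1, 1, 28, 28))  # expected: torch.Size([1, 3, 7, 7])
#       y = MyVAE().decoder(z)                          # expected: torch.Size([1, 1, 28, 28])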
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
# Load the pickled model onto the CPU; the MyVAE class above must be in scope for unpickling.
myVAE2 = torch.load("checkpoint10.pt", map_location=torch.device("cpu"))
# eval() disables the Dropout layers for deterministic reconstructions.
myVAE2.eval()
def sample():
    # Pick a random test image, then show the original, its latent, and its reconstruction.
    idx = torch.randint(0, len(data_test), (1,)).item()
    img, label = data_test[idx]
    img = img.squeeze()  # (28, 28)
    img_original = torchvision.transforms.functional.to_pil_image(img)
    with torch.no_grad():
        latent = myVAE2.encoder(img[None, None, ...])  # (1, 3, 7, 7)
        decoded = myVAE2.decoder(latent)               # (1, 1, 28, 28)
    img_encoded = torchvision.transforms.functional.to_pil_image(latent.squeeze())
    img_decoded = torchvision.transforms.functional.to_pil_image(decoded.squeeze())
    return img_original, img_encoded, img_decoded, labels_map[label]
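# sample() returns three PIL images plus a label string, matching the four Gradio outputs wired up below.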
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Variational Autoencoder</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])
    sampling_button = gr.Button("Sample random FashionMNIST image")
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Original image</h3>""")
            sample_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Encoded image</h3>""")
            encoded_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Decoded image</h3>""")
            decoded_image = gr.Image(height=250, width=200)
    image_label = gr.Label(label="Image label")
    sampling_button.click(
        sample,
        [],
        [sample_image, encoded_image, decoded_image, image_label],
    )

demo.queue().launch(share=False, inbrowser=True)