"""Gradio demo: encode/decode random FashionMNIST test images with a trained model.

The script loads a pickled ``MyVAE`` checkpoint, then serves a small UI that
shows a random test image, its encoder output, and its reconstruction.
"""

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import to_pil_image

# Forward slash works on every platform (the original ".\data" relied on an
# invalid "\d" escape sequence and only resolved to a sub-directory on Windows).
data_test = torchvision.datasets.FashionMNIST(
    root="./data",
    train=False,
    transform=ToTensor(),
    download=True,
)


class MyVAE(nn.Module):
    """Convolutional autoencoder for 1x28x28 images.

    The encoder reduces the input to a 3-channel 7x7 map (two stride-2
    convolutions); the decoder upsamples it back to 1x28x28 and ends in a
    sigmoid.  Note that, despite the name, the encoder output is passed to
    the decoder directly: no sampling / reparameterisation step is defined
    in this class.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 14, 14
            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
            # (nonlinearity)
            nn.SiLU(),
            # (conv_shortcut) is intentionally absent; the second activation
            # is kept so the layer indices stay the same as in the checkpoint.
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 7, 7
            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out)
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 3, 7, 7
        )
        self.decoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 7, 7
            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 14, 14
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16, 28, 28
            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Sigmoid(),
        )

    def forward(self, xb, yb):
        """Reconstruct ``xb`` and return ``(reconstruction, mse_loss)``.

        Args:
            xb: Input batch fed to the encoder.
            yb: Labels; accepted for interface compatibility but unused.

        Returns:
            A tuple of the decoder output and the mean-squared error between
            that output and ``xb``.
        """
        x = self.encoder(xb)
        x = self.decoder(x)
        return x, F.mse_loss(x, xb)


labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# SECURITY NOTE: the checkpoint is a fully pickled module, so loading it
# executes arbitrary pickle code -- only load checkpoints you trust.
# ``weights_only=False`` is spelled out because PyTorch >= 2.6 defaults to
# ``weights_only=True``, which refuses to unpickle a whole ``nn.Module``.
myVAE2 = torch.load(
    "checkpoint10.pt",
    map_location=torch.device('cpu'),
    weights_only=False,
).to("cpu")
myVAE2.eval()


def sample():
    """Pick a random test image and run it through the loaded model.

    Returns:
        A 4-tuple ``(original, encoded, decoded, label)`` where the first
        three entries are PIL images (input, encoder output, decoder output)
        and ``label`` is the class name looked up in ``labels_map``.
    """
    idx = torch.randint(0, len(data_test), (1,)).item()
    print(idx)

    # Fetch the sample once; every dataset access re-runs the transform.
    tensor, target = data_test[idx]
    img = tensor.squeeze()
    print(img.shape)

    batch = img[None, None, ...]  # add batch and channel dims
    # Inference only: skip autograd bookkeeping and reuse the encoder output
    # for the decoder instead of encoding the image a second time.
    with torch.no_grad():
        encoded = myVAE2.encoder(batch)
        decoded = myVAE2.decoder(encoded)

    img_original = to_pil_image(img)
    img_encoded = to_pil_image(encoded.squeeze())
    img_decoded = to_pil_image(decoded.squeeze())
    return (img_original, img_encoded, img_decoded, labels_map[target])


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Variational Autoencoder</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])
    sampling_button = gr.Button("Sample random FashionMNIST image")

    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Original image</h3>""")
            sample_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Encoded image</h3>""")
            encoded_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Decoded image</h3>""")
            decoded_image = gr.Image(height=250, width=200)

    image_label = gr.Label(label="Image label")

    # One click produces all four outputs of ``sample`` in order.
    sampling_button.click(
        sample,
        [],
        [sample_image, encoded_image, decoded_image, image_label],
    )

demo.queue().launch(share=False, inbrowser=True)