wb-droid committed
Commit ee87318 · 1 Parent(s): ca04c47

initial commit

Files changed (3)
  1. checkpoint10.pt +3 -0
  2. requirements.txt +2 -0
  3. vae_app.py +177 -0
checkpoint10.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2313ee8482dc5c4fc6e323146e88df3e1a81791f75156c2dfeb627e588fdb4f4
+ size 664330
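Note: checkpoint10.pt is stored as a Git LFS pointer; the 664,330-byte weights file itself is fetched separately (e.g. with git lfs pull). A minimal sketch, assuming the real file has been pulled into the working directory, for checking that it matches the pointer's oid and size:

import hashlib

# Expected values come from the LFS pointer above.
EXPECTED_OID = "2313ee8482dc5c4fc6e323146e88df3e1a81791f75156c2dfeb627e588fdb4f4"
EXPECTED_SIZE = 664330

with open("checkpoint10.pt", "rb") as f:
    data = f.read()

assert len(data) == EXPECTED_SIZE
assert hashlib.sha256(data).hexdigest() == EXPECTED_OID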
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ torchvision
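Note: vae_app.py also imports torch directly; it is not pinned here because installing torchvision pulls it in as a dependency. A quick sanity check, assuming the two listed packages have been installed:

import gradio, torch, torchvision
print(gradio.__version__, torch.__version__, torchvision.__version__)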
vae_app.py ADDED
@@ -0,0 +1,177 @@
+ import gradio as gr
+ import torch
+ from torch import nn
+ import torch.nn.functional as F  # needed for F.mse_loss in MyVAE.forward
+ import torchvision
+ from torchvision.transforms import ToTensor
+
+ data_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=ToTensor(), download=True)
+
+ class MyVAE(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.encoder = nn.Sequential(
+             # (conv_in)
+             nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
+
+             # (down_block_0)
+             # (norm1)
+             nn.GroupNorm(8, 32, eps=1e-06, affine=True),
+             # (conv1)
+             nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
+             # (norm2):
+             nn.GroupNorm(8, 32, eps=1e-06, affine=True),
+             # (dropout):
+             nn.Dropout(p=0.5, inplace=False),
+             # (conv2):
+             nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 28, 28
+             # (nonlinearity):
+             nn.SiLU(),
+             # (downsamplers)(conv):
+             nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 14, 14
+
+             # (down_block_1)
+             # (norm1)
+             nn.GroupNorm(8, 32, eps=1e-06, affine=True),
+             # (conv1)
+             nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
+             # (norm2):
+             nn.GroupNorm(8, 64, eps=1e-06, affine=True),
+             # (dropout):
+             nn.Dropout(p=0.5, inplace=False),
+             # (conv2):
+             nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
+             # (nonlinearity):
+             nn.SiLU(),
+             # (conv_shortcut):
+             #nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 14, 14
+             # (nonlinearity):
+             nn.SiLU(),
+             # (downsamplers)(conv):
+             nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 7, 7
+
+             # (conv_norm_out):
+             nn.GroupNorm(16, 64, eps=1e-06, affine=True),
+             # (conv_act):
+             nn.SiLU(),
+             # (conv_out): 3-channel latent at 7x7
+             nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+
+             #nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=3//2),  # 14*14
+             #nn.ReLU(),
+             #nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=3//2),  # 7*7
+             #nn.ReLU(),
+         )
+
+         self.decoder = nn.Sequential(
+             # (conv_in):
+             nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+
+             # (norm1):
+             nn.GroupNorm(16, 64, eps=1e-06, affine=True),
+             # (conv1):
+             nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+             # (norm2):
+             nn.GroupNorm(8, 32, eps=1e-06, affine=True),
+             # (dropout):
+             nn.Dropout(p=0.5, inplace=False),
+             # (conv2):
+             nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+             # (nonlinearity):
+             nn.SiLU(),
+
+             # (upsamplers):
+             nn.Upsample(scale_factor=2, mode='nearest'),  # 14, 14
+
+             # (norm1):
+             nn.GroupNorm(8, 32, eps=1e-06, affine=True),
+             # (conv1):
+             nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+             # (norm2):
+             nn.GroupNorm(8, 16, eps=1e-06, affine=True),
+             # (dropout):
+             nn.Dropout(p=0.5, inplace=False),
+             # (conv2):
+             nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+             # (nonlinearity):
+             nn.SiLU(),
+
+             # (upsamplers):
+             nn.Upsample(scale_factor=2, mode='nearest'),  # 16, 28, 28
+
+             # (norm1):
+             nn.GroupNorm(8, 16, eps=1e-06, affine=True),
+             # (conv1):
+             nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+
+             nn.Sigmoid()
+         )
+
+     def forward(self, xb, yb):
+         # Encode then decode; returns the reconstruction and its MSE against the input (yb is unused).
+         x = self.encoder(xb)
+         #print("current:", x.shape)
+         x = self.decoder(x)
+         #print("current decoder:", x.shape)
+         #x = x.flatten(start_dim=1).mean(dim=1, keepdim=True)
+         #print(x.shape, xb.shape)
+         return x, F.mse_loss(x, xb)
+
+ labels_map = {
+     0: "T-Shirt",
+     1: "Trouser",
+     2: "Pullover",
+     3: "Dress",
+     4: "Coat",
+     5: "Sandal",
+     6: "Shirt",
+     7: "Sneaker",
+     8: "Bag",
+     9: "Ankle Boot",
+ }
+
+ myVAE2 = torch.load("checkpoint10.pt").to("cpu")
+ myVAE2.eval()
+
+ def sample():
+     # Pick a random test image and return the original, its encoded (latent) view,
+     # its reconstruction, and the class label.
+     idx = torch.randint(0, len(data_test), (1,))
+     print(idx.item())
+     print(data_test[idx.item()][0].squeeze().shape)
+     img = data_test[idx.item()][0].squeeze()
+     img_original = torchvision.transforms.functional.to_pil_image(img)
+     img_encoded = torchvision.transforms.functional.to_pil_image(myVAE2.encoder(img[None, None, ...]).squeeze())
+     img_decoded = torchvision.transforms.functional.to_pil_image(myVAE2.decoder(myVAE2.encoder(img[None, None, ...])).squeeze())
+
+     return (img_original, img_encoded, img_decoded, labels_map[data_test[idx.item()][1]])
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">Variational Autoencoder</h1>""")
+     gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
+     session_data = gr.State([])
+
+     sampling_button = gr.Button("Sample random FashionMNIST image")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             #gr.Label("Original image")
+             gr.HTML("""<h3 align="left">Original image</h3>""")
+             sample_image = gr.Image(height=250, width=200)
+         with gr.Column(scale=2):
+             #gr.Label("Encoded image")
+             gr.HTML("""<h3 align="left">Encoded image</h3>""")
+             encoded_image = gr.Image(height=250, width=200)
+         with gr.Column(scale=2):
+             gr.HTML("""<h3 align="left">Decoded image</h3>""")
+             #gr.Label("Decoded image")
+             decoded_image = gr.Image(height=250, width=200)
+
+     image_label = gr.Label(label="Image label")
+
+     sampling_button.click(
+         sample,
+         [],
+         [sample_image, encoded_image, decoded_image, image_label],
+     )
+
+ demo.queue().launch(share=False, inbrowser=True)
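The app only loads checkpoint10.pt; no training script is part of this commit. A rough sketch of how such a checkpoint could be produced is below — the optimizer, learning rate, batch size, and epoch count are assumptions for illustration, not values taken from this repository. Only MyVAE, the FashionMNIST dataset, and the checkpoint file name come from the code above.

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# Assumed training setup; MyVAE is the class defined in vae_app.py.
data_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=ToTensor(), download=True)
loader = DataLoader(data_train, batch_size=64, shuffle=True)

model = MyVAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):            # placeholder epoch count
    for xb, yb in loader:
        _, loss = model(xb, yb)    # forward() returns (reconstruction, MSE loss)
        opt.zero_grad()
        loss.backward()
        opt.step()

# vae_app.py calls torch.load("checkpoint10.pt") on a full module object,
# so the whole model (not just a state_dict) is saved here.
torch.save(model, "checkpoint10.pt")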