etweedy commited on
Commit
cbb6fd7
·
1 Parent(s): 13a3f81

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -14
app.py CHANGED
@@ -2,40 +2,50 @@ import torch
2
  from torch import nn
3
  import gradio as gr
4
 
 
5
  class CNN(nn.Module):
6
- def __init__(self):
7
- super(CNN,self).__init__()
8
-
9
- self.conv1 = nn.Sequential(
10
- nn.Conv2d(1,16,5,stride=1,padding=2),
11
- nn.ReLU(),
12
- nn.MaxPool2d(kernel_size=2),
13
- )
14
- self.conv2 = nn.Sequential(
15
- nn.Conv2d(16,32,5,1,2),
16
- nn.ReLU(),
17
- nn.MaxPool2d(2),
18
- )
 
 
19
  self.out = nn.Linear(32*7*7,10)
20
 
 
21
  def forward(self,x):
22
  x=self.conv1(x)
23
  x=self.conv2(x)
24
  x = x.view(-1,32*7*7)
25
  return self.out(x)
26
 
 
27
  model = CNN()
28
  model.load_state_dict(torch.load('mnist2.pkl',map_location=torch.device('cpu')))
29
  model.eval()
30
 
 
31
  def predict(img):
32
  x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
33
  with torch.no_grad():
34
  pred = model(x)[0]
35
  return int(pred.argmax())
36
 
37
-
 
 
38
  gr.Interface(fn=predict,
39
  inputs="sketchpad",
40
  outputs="label",
 
 
41
  ).launch()
 
2
  from torch import nn
3
  import gradio as gr
4
 
5
+ # Define the custom CNN model class that was trained on the MNIST data
6
  class CNN(nn.Module):
7
+ """
8
+ A custom CNN class. The network has: (1) a convolution layer with 1 input channel and 16 output channels with ReLU activation and 2x2 max-pooling, (2) a second convolution layer with 16 input channels and 32 output channels with ReLU activation and 2x2 max-pooling, and (3) a linear output layer with 10 outputs.
9
+ """
10
+ def __init__(self):
11
+ super(CNN,self).__init__()
12
+ self.conv1 = nn.Sequential(
13
+ nn.Conv2d(1,16,5,stride=1,padding=2),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(kernel_size=2),
16
+ )
17
+ self.conv2 = nn.Sequential(
18
+ nn.Conv2d(16,32,5,1,2),
19
+ nn.ReLU(),
20
+ nn.MaxPool2d(2),
21
+ )
22
  self.out = nn.Linear(32*7*7,10)
23
 
24
+ # Forward propogation method
25
  def forward(self,x):
26
  x=self.conv1(x)
27
  x=self.conv2(x)
28
  x = x.view(-1,32*7*7)
29
  return self.out(x)
30
 
31
+ # Initialize an instance and load in the saved state_dict for the trained model
32
  model = CNN()
33
  model.load_state_dict(torch.load('mnist2.pkl',map_location=torch.device('cpu')))
34
  model.eval()
35
 
36
+ # Prediction function
37
  def predict(img):
38
  x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
39
  with torch.no_grad():
40
  pred = model(x)[0]
41
  return int(pred.argmax())
42
 
43
+ # Define and launch gradio interfact with sketchopad input and classification label output
44
+ title = "Guess that digit"
45
+ description = "Draw your favorite base-10 digit (0-9) and click submit - I'll try to guess what you drew! I do a bit better if you're not too messy and your digit is fairly centered."
46
  gr.Interface(fn=predict,
47
  inputs="sketchpad",
48
  outputs="label",
49
+ title = title,
50
+ description = description,
51
  ).launch()