resberry commited on
Commit
5e28cf6
·
verified ·
1 Parent(s): 3970115

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+
6
+ # Define the CNN model
7
+ class CNN(torch.nn.Module):
8
+ def __init__(self):
9
+ super(CNN, self).__init__()
10
+ self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
11
+ self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
12
+ self.pool = torch.nn.MaxPool2d(2, 2)
13
+ self.fc1 = torch.nn.Linear(64 * 14 * 14, 128)
14
+ self.fc2 = torch.nn.Linear(128, 10)
15
+ self.relu = torch.nn.ReLU()
16
+ self.dropout = torch.nn.Dropout(0.25)
17
+
18
+ def forward(self, x):
19
+ x = self.relu(self.conv1(x))
20
+ x = self.pool(self.relu(self.conv2(x)))
21
+ x = x.view(x.size(0), -1) # Flatten dynamically based on batch size
22
+ x = self.relu(self.fc1(x))
23
+ x = self.dropout(x)
24
+ x = self.fc2(x)
25
+ return x
26
+
27
+ # Load the trained model
28
+ model = CNN()
29
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location=torch.device('cpu')))
30
+ model.eval()
31
+
32
+ # Define the prediction function
33
+ def predict(image):
34
+ transform = transforms.Compose([
35
+ transforms.Grayscale(), # Ensure the input image is grayscale
36
+ transforms.Resize((28, 28)), # Resize the image to 28x28 pixels
37
+ transforms.ToTensor(),
38
+ transforms.Normalize((0.5,), (0.5,)) # Normalize the image
39
+ ])
40
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
41
+ with torch.no_grad():
42
+ output = model(image_tensor)
43
+ predicted_class = output.argmax(dim=1).item() # Get the predicted class
44
+ return f"Predicted digit: {predicted_class}"
45
+
46
+ # Create the Gradio interface
47
+ interface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil"), # Updated input component
50
+ outputs="text",
51
+ title="Handwritten Digit Classifier",
52
+ description="Upload an image of a handwritten digit, and the model will predict the digit."
53
+ )
54
+
55
+ # Launch the Gradio app
56
+ if __name__ == "__main__":
57
+ interface.launch()