louiecerv commited on
Commit
744b6e5
·
1 Parent(s): ba18a4d

sync to remote

Browse files
Files changed (2) hide show
  1. app.py +73 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from huggingface_hub import hf_hub_download
7
+ import os
8
+
9
+ # Define the CNN model class
10
+ class CNNClassifier(nn.Module):
11
+ def __init__(self, n_classes):
12
+ super(CNNClassifier, self).__init__()
13
+ self.model = nn.Sequential(
14
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
15
+ nn.BatchNorm2d(64),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(2, stride=2),
18
+
19
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
20
+ nn.BatchNorm2d(128),
21
+ nn.ReLU(),
22
+ nn.Dropout(0.2),
23
+ nn.MaxPool2d(2, stride=2),
24
+
25
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
26
+ nn.BatchNorm2d(256),
27
+ nn.ReLU(),
28
+ nn.MaxPool2d(2, stride=2),
29
+
30
+ nn.AdaptiveAvgPool2d(1),
31
+ nn.Flatten(),
32
+ nn.Linear(256, n_classes)
33
+ )
34
+
35
+ def forward(self, x):
36
+ return self.model(x)
37
+
38
+ hf_token = os.getenv("HF_TOKEN")
39
+
40
+ # Load the model from Hugging Face
41
+ model_path = hf_hub_download(repo_id="louiecerv/cats_dogs_recognition_torch_cnn",
42
+ filename="cats_dogs_classifier.pth", use_auth_token=hf_token)
43
+ n_classes = 2
44
+ model = CNNClassifier(n_classes)
45
+ model.load_state_dict(torch.load(model_path))
46
+ model.eval()
47
+
48
+ # Define the transformation pipeline
49
+ transform = transforms.Compose([
50
+ transforms.Resize((128, 128)),
51
+ transforms.ToTensor(),
52
+ ])
53
+
54
+ # Streamlit app
55
+ st.title("Cat vs Dog Classifier")
56
+ st.write("Upload an image and the model will classify it as a cat or a dog.")
57
+
58
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
59
+
60
+ if uploaded_file is not None:
61
+ image = Image.open(uploaded_file)
62
+ st.image(image, caption="Uploaded Image", use_container_width=True)
63
+
64
+ # Preprocess the image
65
+ image = transform(image).unsqueeze(0)
66
+
67
+ # Make prediction
68
+ with torch.no_grad():
69
+ outputs = model(image)
70
+ _, predicted = torch.max(outputs, 1)
71
+ label = "Cat" if predicted.item() == 0 else "Dog"
72
+
73
+ st.write(f"The model predicts this image is a: **{label}**")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ huggingface_hub