"""Streamlit app that classifies an uploaded image as a cat or a dog.

The CNN weights are downloaded from the Hugging Face Hub and loaded once per
server process; each uploaded image is resized, converted to a tensor and
passed through the network to obtain the predicted class.
"""

import os

import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

REPO_ID = "louiecerv/cats_dogs_recognition_torch_cnn"
WEIGHTS_FILENAME = "cats_dogs_classifier.pth"
# Index 0 is reported as "Cat" and index 1 as "Dog".
CLASS_NAMES = ("Cat", "Dog")


# Define the CNN model class
class CNNClassifier(nn.Module):
    """Three-stage convolutional classifier for 3-channel images.

    Each stage is Conv -> BatchNorm -> ReLU -> MaxPool (the second stage also
    applies dropout). The feature map is then reduced to a 256-element vector
    by adaptive average pooling and mapped to ``n_classes`` logits.

    Args:
        n_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        # NOTE: the attribute name ``model`` and the layer order determine the
        # state-dict keys, so they must stay as-is to load the saved weights.
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, n_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.model(x)


hf_token = os.getenv("HF_TOKEN")
n_classes = len(CLASS_NAMES)


@st.cache_resource
def load_model(n_classes, token=None):
    """Download the trained weights and return the model in eval mode.

    Streamlit re-runs this script on every interaction; caching the result
    avoids repeating the Hub lookup and the weight loading each time.

    Args:
        n_classes: Number of output classes the checkpoint was trained with.
        token: Optional Hugging Face access token used for the download.

    Returns:
        A ``CNNClassifier`` with the downloaded weights, set to eval mode.
    """
    # Load the model from Hugging Face
    weights_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=WEIGHTS_FILENAME,
        token=token,
    )
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # weights_only restricts unpickling to tensors and plain containers, which
    # is all a state dict needs and is safer for a downloaded file.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    classifier = CNNClassifier(n_classes)
    classifier.load_state_dict(state_dict)
    classifier.eval()
    return classifier


model = load_model(n_classes, hf_token)

# Define the transformation pipeline
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Streamlit app
st.title("Cat vs Dog Classifier")
st.write("Upload an image and the model will classify it as a cat or a dog.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # PNG uploads may be RGBA, palette or grayscale; the first convolution
        # expects exactly three channels, so normalise the mode up front.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow raises OSError subclasses for unreadable or truncated files.
        st.error("The uploaded file could not be read as an image.")
        st.stop()

    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Preprocess the image: add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Make prediction
    with torch.no_grad():
        outputs = model(batch)
    predicted_index = outputs.argmax(dim=1).item()

    label = CLASS_NAMES[predicted_index]
    st.write(f"The model predicts this image is a: **{label}**")