# Spaces:
# Sleeping
# Sleeping
# app.py | |
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
# Class names, in the index order of the model's output logits
# (adjust if yours are different -- the classifier head below is sized from
# this list, so the two cannot drift apart).
class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Load model: a ResNet-18 built without pretrained weights, whose final fully
# connected layer is replaced by one output per class; the trained parameters
# are then read from garbage_classifier.pt onto the CPU.
model = resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load("garbage_classifier.pt", map_location="cpu"))
# Evaluation mode: BatchNorm uses running statistics instead of batch ones.
model.eval()
# Image preprocessing: every input is resized to exactly 224x224 pixels
# (aspect ratio is not preserved) and converted to a float tensor in [0, 1].
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Prediction function
def classify_image(img):
    """Classify one image and return a probability for every class.

    Args:
        img: The image to classify, either a ``numpy.ndarray`` (e.g. as passed
            by Gradio's "image" input) or a ``PIL.Image.Image``. ``None``
            (submitted without an image) is rejected with a user-facing error.

    Returns:
        dict mapping each name in ``class_names`` to its softmax probability
        (a Python float), the format consumed by Gradio's "label" output.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Without this guard the transform fails with an opaque TypeError;
        # gr.Error is shown to the user as a readable message instead.
        raise gr.Error("Please upload an image to classify.")
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # ResNet-18's first convolution expects 3 input channels, so RGBA,
    # greyscale and palette images must be converted before preprocessing.
    img = img.convert("RGB")

    batch = transform(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    # Pair names with probabilities directly rather than via range(6), so the
    # mapping stays correct if class_names is edited.
    return dict(zip(class_names, probs.tolist()))
# Build the Gradio UI: one image input whose per-class probabilities are
# shown by a "label" output.
demo = gr.Interface(fn=classify_image, inputs="image", outputs="label", title="Trash Classifier")

# Launch only when executed as a script (`python app.py`), so importing this
# module -- e.g. from tests or Gradio's reload mode, which looks up `demo` --
# does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()