|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import timm |
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the saved training checkpoint, mapping its stored tensors onto the
# selected device.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    "./model.pth",
    map_location=torch.device(device),
)

# Recreate the EfficientNet-B0 architecture with a 12-way output head (no
# pretrained download), then restore the fine-tuned weights from the
# checkpoint's "model_state_dict" entry.
model = timm.create_model(
    "efficientnet_b0",
    pretrained=False,
    num_classes=12,
)
model.load_state_dict(checkpoint["model_state_dict"])

# Place the network on the target device and switch it to evaluation mode
# for inference (Module.eval() returns the module itself).
model = model.to(device).eval()
|
|
|
# Human-readable names for the model's 12 outputs: entry i is the label for
# output index i. The names are in alphabetical order -- presumably matching
# the class-index order used at training time (e.g. torchvision ImageFolder);
# TODO confirm against the training script before reordering anything.
class_labels = [
    "battery", "biological", "brown-glass", "cardboard",
    "clothes", "green-glass", "metal", "paper",
    "plastic", "shoes", "trash", "white-glass",
]
|
|
|
# Standard ImageNet per-channel (R, G, B) mean and standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded PIL image before inference:
# resize to a fixed 224x224 (aspect ratio is not preserved), convert to a
# float tensor in [0, 1], then normalize each channel with the stats above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
def predict(image):
    """Classify a single garbage photo and return its category name.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user presses Submit without uploading anything.

    Returns:
        The entry of ``class_labels`` whose model output score is highest.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # The Normalize step in `transform` has 3-element mean/std, so the tensor
    # must have exactly 3 channels; RGBA, grayscale ("L"), palette ("P") or
    # CMYK images would otherwise fail with a channel-mismatch error.
    # (gr.Image may already convert via its image_mode setting; this keeps
    # predict safe for any PIL image it is handed.)
    rgb_image = image.convert("RGB")

    # unsqueeze(0) adds a leading batch dimension of size 1.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = model(batch)

    # Index of the highest score for the single image in the batch.
    class_index = logits.argmax(dim=1).item()
    return class_labels[class_index]
|
|
|
# Web UI: one image upload in, the predicted class name out as text.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="TSYP Garbage Classification Model",
    # The category count is derived from class_labels so this text cannot
    # drift from the label list; the rendered string is unchanged.
    description=(
        f"Upload an image of garbage to classify it into one of {len(class_labels)} "
        "categories(make sure it's the only thing in the photo , except background)"
    ),
)

# Start the (blocking) server only when this file is run as a script, not
# when it is imported (e.g. by tests or by another app reusing `interface`).
if __name__ == "__main__":
    interface.launch()
|
|