ghassen-fatnassi
wip
0dfb5a6
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
import timm
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Restore the trained classifier: the checkpoint is a dict whose
# 'model_state_dict' entry holds the weights for a 12-class EfficientNet-B0.
# NOTE(review): torch.load unpickles the file, so only ever point it at a
# checkpoint you trust; consider weights_only=True once the checkpoint's
# contents are confirmed to be plain tensors/containers.
checkpoint = torch.load("./model.pth", map_location=torch.device(device))
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=12)
model.load_state_dict(checkpoint['model_state_dict'])
# Module.eval() returns the module itself, so placement and inference mode
# can be chained; eval() switches dropout/batch-norm to inference behaviour.
model = model.to(device).eval()
# Output-index -> category name for the model's 12 logits. The names are kept
# in alphabetical order; this order must match the class indices used when the
# checkpoint was trained.
# NOTE(review): presumably derived from sorted class-folder names (e.g.
# torchvision ImageFolder) at training time — confirm against the training code.
class_labels = [
    "battery", "biological", "brown-glass", "cardboard",
    "clothes", "green-glass", "metal", "paper",
    "plastic", "shoes", "trash", "white-glass",
]
# Preprocessing applied to every uploaded PIL image before it reaches the model.
transform = transforms.Compose(
    [
        # Squash to a fixed 224x224 (EfficientNet-B0 input size); aspect ratio
        # is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the usual ImageNet RGB statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image):
    """Classify one garbage photo and return its predicted category name.

    Args:
        image: A ``PIL.Image.Image`` as delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``class_labels`` corresponding to the highest logit.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a traceback from the preprocessing pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Force 3-channel RGB: RGBA, grayscale ("L") or palette ("P") images would
    # otherwise yield a channel count other than 3, which the 3-element
    # Normalize mean/std and the network's RGB input cannot accept.
    rgb_image = image.convert("RGB")

    # Preprocess, add a leading batch dimension of 1, and move to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(batch)

    # Index of the highest-scoring class for the single item in the batch.
    class_index = logits.argmax(dim=1).item()
    return class_labels[class_index]
# Single image in, predicted label (plain text) out. type="pil" makes Gradio
# hand predict() a PIL image rather than a numpy array.
image_input = gr.Image(type="pil")
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
    title="TSYP Garbage Classification Model",
    description="Upload an image of garbage to classify it into one of 12 categories(make sure it's the only thing in the photo , except background)",
)

# Start the web server; this runs as soon as the module is executed.
interface.launch()