# Hugging Face Spaces app (Gradio) — fruit/vegetable freshness classifier.
# NOTE(review): the scraped Space page reported "Runtime error"; header
# residue from the scrape was converted to this comment so the file parses.
import torch
from torchvision import transforms

import gradio as gr

import model_builder as mb

# Inference device — CPU only (free Spaces hardware has no GPU).
device = torch.device("cpu")
# Preprocessing pipeline matching EfficientNet's ImageNet training recipe:
# array -> PIL image -> 224x224 resize -> float tensor -> ImageNet-stats normalize.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
manual_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
# Output labels in the exact index order the model's logits were trained on:
# all 8 "Fresh" classes first, then the same 8 fruits as "Spoiled".
_FRUITS = (
    "Banana",
    "Lemon",
    "Lulo",
    "Mango",
    "Orange",
    "Strawberry",
    "Tamarillo",
    "Tomato",
)
class_names = [
    f"{state} {fruit}" for state in ("Fresh", "Spoiled") for fruit in _FRUITS
]
# Build the EfficientNet-B0 baseline with one output per class, then restore
# the trained weights from the bundled checkpoint (forced onto CPU).
model_0 = mb.create_model_baseline_effnetb0(out_feats=len(class_names), device=device)
state_dict = torch.load(f="models/effnetb0_fruitsvegs0_5_epochs.pt", map_location="cpu")
model_0.load_state_dict(state_dict)
def pred(img):
    """Classify a single image and return a human-readable result string.

    Args:
        img: the image handed over by ``gr.Image()`` — presumably an
            HxWxC uint8 numpy array (TODO confirm against the Gradio
            version in use; ``ToPILImage`` in the transform accepts it).

    Returns:
        A string of the form
        ``"prediction: <class name> | confidence: <0.xxx>"``.
    """
    model_0.eval()  # make sure dropout/batch-norm run in inference mode
    transformed = manual_transform(img).to(device)
    with torch.inference_mode():
        # unsqueeze adds the batch dimension: (C, H, W) -> (1, C, H, W)
        logits = model_0(transformed.unsqueeze(dim=0))
        # probabilities over the class axis; renamed from `pred`, which
        # shadowed the function's own name
        probs = torch.softmax(logits, dim=-1)
    return f"prediction: {class_names[probs.argmax(dim=-1).item()]} | confidence: {probs.max():.3f}"
# Minimal Gradio UI: an image input wired to `pred`, plain text output.
demo = gr.Interface(fn=pred, inputs=gr.Image(), outputs="text")
demo.launch()