foodvision_mini / app.py
tituslhy's picture
bug fix
f535307
"""The main parts are:
1. Imports and class names setup
2. Model and transforms preparation
3. Write a predict function for gradio to use
4. Write the Gradio app and the launch command
"""
import os
from typing import Tuple, Dict, List
import PIL
import torch
from torch import nn
import torchvision
import gradio as gr
from timeit import default_timer as timer
from model import create_effnetb2_model
class_names = ['pizza', 'steak', 'sushi'] #hardcoded as a list
model, transforms = create_effnetb2_model(num_classes = len(class_names))
# Load saved weights into the model, and load the model onto the CPU
model.load_state_dict(torch.load(f = "finetuned_effnetb2_20percent.pth", map_location = torch.device('cpu')))
# Write function to run inference on gradio
def predict(img: PIL.Image,
model: nn.Module = model,
transforms: torchvision.transforms = transforms,
class_names: List[str] = class_names) -> Tuple[Dict, float]:
"""Function to predict image class on gradio
Args:
img (np.array): Image as a numpy array
model (nn.Module, optional): Model. Defaults to vit.
class_names (List[str], optional): List of class anmes. Defaults to class_names.
Returns:
Tuple[Dict, float]: Tuplefor further processing on gradio
"""
start_time = timer()
img = transforms(img).unsqueeze(0) #add batch dimension
model.eval()
with torch.inference_mode():
pred_probs = torch.softmax(model(img), dim = 1)
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
end_time = timer()
pred_time = round(end_time - start_time, 4)
return pred_labels_and_probs, pred_time
# Create example_list
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Create Gradio App
title = 'FoodVision Mini πŸ•πŸ₯©πŸ£'
description = "Using a [Vision Transformer](https://arxiv.org/abs/2010.11929) for Image Classification"
article = "Created by [Titus Lim](https://github.com/tituslhy)"
demo = gr.Interface(fn = predict,
inputs = gr.Image(type = "pil"),
outputs = [gr.Label(num_top_classes = 3, label = "Predictions"),
gr.Number(label = "Prediction time (s)")],
examples = example_list,
title = title,
description = description,
article = article)
# Launch demo
demo.launch(debug = False, #prints errors locally
)