"""Gradio app for FoodVision Mini: classifies food images as pizza, steak, or sushi.

The main parts are:
1. Imports and class names setup
2. Model and transforms preparation
3. A predict function for Gradio to use
4. The Gradio app and the launch command
"""
import os
from timeit import default_timer as timer
from typing import Callable, Dict, List, Tuple

import gradio as gr
import PIL
import PIL.Image
import torch
import torchvision
from torch import nn

from model import create_effnetb2_model
# Hard-coded label list: output index i of the classifier is reported as class_names[i].
class_names = ["pizza", "steak", "sushi"]

# Build the EffNetB2 network with one output unit per class. The factory also returns
# the transforms that are later applied to every input image before inference.
model, transforms = create_effnetb2_model(num_classes=len(class_names))

# Restore the fine-tuned weights. map_location remaps every stored tensor onto the CPU,
# so a checkpoint saved from a GPU session still loads on CPU-only hardware.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load(
        f="finetuned_effnetb2_20percent.pth",
        map_location=torch.device("cpu"),
    )
)
# Write function to run inference on gradio | |
def predict(img: PIL.Image, | |
model: nn.Module = model, | |
transforms: torchvision.transforms = transforms, | |
class_names: List[str] = class_names) -> Tuple[Dict, float]: | |
"""Function to predict image class on gradio | |
Args: | |
img (np.array): Image as a numpy array | |
model (nn.Module, optional): Model. Defaults to vit. | |
class_names (List[str], optional): List of class anmes. Defaults to class_names. | |
Returns: | |
Tuple[Dict, float]: Tuplefor further processing on gradio | |
""" | |
start_time = timer() | |
img = transforms(img).unsqueeze(0) #add batch dimension | |
model.eval() | |
with torch.inference_mode(): | |
pred_probs = torch.softmax(model(img), dim = 1) | |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))} | |
end_time = timer() | |
pred_time = round(end_time - start_time, 4) | |
return pred_labels_and_probs, pred_time | |
# Create example_list | |
example_list = [["examples/" + example] for example in os.listdir("examples")] | |
# Create Gradio App | |
title = 'FoodVision Mini ππ₯©π£' | |
description = "Using a [Vision Transformer](https://arxiv.org/abs/2010.11929) for Image Classification" | |
article = "Created by [Titus Lim](https://github.com/tituslhy)" | |
demo = gr.Interface(fn = predict, | |
inputs = gr.Image(type = "pil"), | |
outputs = [gr.Label(num_top_classes = 3, label = "Predictions"), | |
gr.Number(label = "Prediction time (s)")], | |
examples = example_list, | |
title = title, | |
description = description, | |
article = article) | |
# Launch demo | |
demo.launch(debug = False, #prints errors locally | |
) | |