File size: 2,542 Bytes
83f96f9
 
 
 
 
 
 
c34ea29
83f96f9
 
c34ea29
83f96f9
 
c34ea29
83f96f9
 
 
 
 
 
a4b0982
83f96f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""The main parts are:
1. Imports and class names setup
2. Model and transforms preparation
3. Write a predict function for gradio to use
4. Write the Gradio app and the launch command
"""
import os
from typing import Tuple, Dict, List
import PIL
import torch
from torch import nn
import torchvision
import gradio as gr
from timeit import default_timer as timer
from model import create_effnetb2_model

# Class labels. predict() reports the model's i-th output as class_names[i],
# so this order presumably must match the label order used when the
# checkpoint was trained (TODO confirm against the training script).
class_names = ['pizza', 'steak', 'sushi']

# Build the EffNetB2 classifier with an output head sized to our labels,
# together with the input transforms that create_effnetb2_model returns.
model, transforms = create_effnetb2_model(num_classes=len(class_names))

# Restore the fine-tuned weights. map_location remaps every stored tensor to
# the CPU, so the checkpoint loads even when no GPU is present.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(
        f="finetuned_effnetb2_20percent.pth",
        map_location=torch.device('cpu'),
    )
)

# Write function to run inference on gradio
def predict(img: PIL.Image, 
            model: nn.Module = model,
            transforms: torchvision.transforms = transforms,
            class_names: List[str] = class_names) -> Tuple[Dict, float]:
    """Function to predict image class on gradio

    Args:
        img (np.array): Image as a numpy array
        model (nn.Module, optional): Model. Defaults to vit.
        class_names (List[str], optional): List of class anmes. Defaults to class_names.

    Returns:
        Tuple[Dict, float]: Tuplefor further processing on gradio
    """
    start_time = timer()
    img = transforms(img).unsqueeze(0) #add batch dimension
    model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(model(img), dim = 1)
    pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
    end_time = timer()
    pred_time = round(end_time - start_time, 4)
    return pred_labels_and_probs, pred_time

# Create example_list: one single-element row per file in examples/, the
# [[input], ...] format gr.Interface expects for a one-input app.
# os.listdir returns entries in arbitrary order, so sort for a stable example
# order. Hidden files (e.g. .DS_Store, .gitkeep) are skipped; they are not
# images and would break the example gallery.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

# Create Gradio App
title = 'FoodVision Mini 🍕🥩🍣'
# The served model is EfficientNetB2 (built by create_effnetb2_model), so the
# description links the EfficientNet paper, not the Vision Transformer paper.
description = "Using an [EfficientNetB2](https://arxiv.org/abs/1905.11946) model for Image Classification"
article = "Created by [Titus Lim](https://github.com/tituslhy)"

demo = gr.Interface(
    fn=predict,
    # type="pil" makes Gradio hand predict() a PIL image.
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class; derived from class_names so it tracks the label set.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch demo. debug is left at False (Gradio's default). Per the Gradio docs,
# debug=True blocks the main thread and is needed to print errors in notebook
# environments such as Google Colab.
demo.launch(debug=False)