File size: 1,781 Bytes
1d31025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
import os
import torch

from pathlib import Path

from model import create_effnetb3_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Human-readable output labels; the label at position i is reported for the
# model's i-th output probability.
# NOTE(review): presumably this mirrors the label order used during training
# (it is alphabetical, as torchvision's ImageFolder would produce) — confirm
# against the training pipeline; reordering it silently mislabels predictions.
class_names = [
    'Banh beo',
    'Banh bot loc',
    'Banh can',
    'Banh canh',
    'Banh chung',
    'Banh cuon',
    'Banh duc',
    'Banh gio',
    'Banh khot',
    'Banh mi',
    'Banh pia',
    'Banh tet',
    'Banh trang nuong',
    'Banh xeo',
    'Bun bo Hue',
    'Bun dau mam tom',
    'Bun mam',
    'Bun rieu',
    'Bun thit nuong',
    'Ca kho to',
    'Canh chua',
    'Cao lau',
    'Chao long',
    'Com tam',
    'Goi cuon',
    'Hu tieu',
    'Mi quang',
    'Nem chua',
    'Pho',
    'Xoi xeo',
]

# Build the EfficientNetB3 model and the transform applied to each input image
# before inference. The number of outputs is derived from class_names rather
# than hard-coded, so the classifier head cannot drift out of sync with the
# label list.
effnetb3, effnetb3_transforms = create_effnetb3_model(num_classes=len(class_names))

# Load the fine-tuned weights, mapping every tensor onto the CPU so the app
# also runs on machines without a GPU.
# NOTE(review): torch.load unpickles the checkpoint, so only point it at a
# trusted file. On torch>=1.13, passing weights_only=True would restrict
# loading to tensors — confirm the deployed torch version before enabling it.
effnetb3.load_state_dict(
    torch.load(
        f="./models/pretrained_effnetb3_vietnamese_food.pth",
        map_location=torch.device("cpu"),
    )
)

def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long it took.

    Args:
        img: The value from the ``gr.Image(type="pil")`` input, i.e. a PIL
            image, or ``None`` when the user submits without uploading one.

    Returns:
        A ``(label_probs, seconds)`` tuple: a dict mapping every class name to
        its softmax probability, and the elapsed wall-clock time (transform +
        forward pass) rounded to 4 decimal places.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of an exception from the transform pipeline.
    """
    if img is None:
        raise gr.Error("Please upload an image to classify.")

    start_time = timer()

    # Preprocess and add a leading batch dimension of size 1.
    batch = effnetb3_transforms(img).unsqueeze(0)

    effnetb3.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb3(batch), dim=1)

    # Convert the single row to Python floats in one call and pair each
    # probability with its label by position.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)

    return pred_labels_and_probs, pred_time

title = "Vietnamese food vision"
description = "An EfficientNetB3 feature extractor computer vision model"

# Collect example images for the UI. Sort for a stable order (directory
# listing order is arbitrary), skip sub-directories and hidden files such as
# .DS_Store / .gitkeep that the image input cannot load, and fall back to no
# examples if the folder is missing instead of crashing at startup.
examples_dir = Path("examples")
example_list = (
    [
        [str(path)]
        for path in sorted(examples_dir.iterdir())
        if path.is_file() and not path.name.startswith(".")
    ]
    if examples_dir.is_dir()
    else []
)

demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=[gr.Label(num_top_classes=3, label="Prediction"),
                             gr.Number(label="Prediction time (s)")],
                    # None (the Gradio default) when there are no examples.
                    examples=example_list or None,
                    title=title,
                    description=description)

# NOTE(review): share=True requests a public share link when run locally;
# hosted environments such as Hugging Face Spaces may ignore it — confirm.
demo.launch(share=True)