import gradio as gr
import torch 
import torchvision
import os 

from model import model_efficientb3
from timeit import default_timer as Timer
from typing import Tuple,Dict

# Class labels. predict() pairs the label at index i with the model's output
# probability at index i, so the order here presumably has to match the class
# order used when the checkpoint was trained (verify against training code).
class_name = ["pizza", "steak", "sushi"]

# model_efficientb3 returns the network and the transform that predict()
# applies to every input image. The number of outputs is derived from the
# label list so the two cannot silently drift apart.
effnetb3, effentb3_tranforms = model_efficientb3(out_feature=len(class_name))

# Load the saved weights, remapping all tensors to the CPU so the app does not
# need the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only ship checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
effnetb3.load_state_dict(
    torch.load(
        f="09_pretrained_effnetb3_feature_extractor_pizza_steak_sushi_20_percent.pth",
        map_location=torch.device("cpu"),
    )
)

def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an image and report how long the prediction took.

    Args:
        img: Input image, passed unchanged to ``effentb3_tranforms``. The
            Gradio input component in this file is configured with
            ``type="pil"``.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dict mapping every
        name in ``class_name`` to its softmax probability, and the elapsed
        time in seconds (transform + forward pass) rounded to 5 decimals.
    """
    start_time = Timer()

    # Apply the model's transform and add a leading batch dimension.
    batch = effentb3_tranforms(img).unsqueeze(0)

    # Evaluation mode plus inference_mode: no gradient tracking during the
    # forward pass.
    effnetb3.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb3(batch), dim=1)

    # Row 0 is the single image in the batch; convert each tensor element to a
    # plain Python float for the Gradio Label output.
    pred_labels_and_probs = {
        label: float(pred_probs[0][index])
        for index, label in enumerate(class_name)
    }

    pred_time = round(Timer() - start_time, 5)

    return pred_labels_and_probs, pred_time


# Text shown in the Gradio UI.
title = "FoodVision Mini 🍕🥩🍣"
# The model built and loaded in this file is EfficientNet-B3, so the
# description names B3.
description = "An EfficientNetB3 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
article = "tryin to learn pytorch"

# One single-element list per file in "examples/" (one value per input
# component). os.listdir returns entries in arbitrary order, so sort them to
# keep the order of the examples stable between runs.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
]

# Wire predict() to an image input and two outputs: the per-class
# probabilities and the prediction time.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Prediction"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

demo.launch()