balakrish181's picture
mapped to cpu
761f68b
import gradio as gr
from model import model_classification
import torch,os
class_names = ['cat', 'dog']  # predict() labels logit column i with class_names[i]
path = 'efficient_cat_dog.pth'  # resolved relative to the current working directory

# model_classification() (from model.py) returns the network together with the
# preprocessing transform that predict() applies to each input image.
model, transforms = model_classification()

# Remap every checkpoint tensor onto the CPU so the weights load even on a
# machine without CUDA, then copy them into the freshly built network.
_weights = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(_weights)
del _weights  # drop the extra reference; the model's parameters hold the values now
def predict(img):
    """Classify one image as one of ``class_names``.

    Args:
        img: Value from the ``gr.Image(type='pil')`` input. This is a PIL image,
            or ``None`` when the user submits without providing an image.

    Returns:
        A ``{class_name: probability}`` dict for the ``gr.Label`` output. Returns
        ``None`` (leaving the label empty) when no image was supplied, instead
        of crashing inside the preprocessing transform.
    """
    if img is None:
        return None

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transforms(img).unsqueeze(0)

    model.eval()  # inference behaviour for layers such as dropout / batch-norm
    with torch.inference_mode():
        logits = model(batch)
        # Probabilities over the class axis; keep the single row of the batch.
        probs = torch.softmax(logits, dim=1)[0]

    return {name: float(p) for name, p in zip(class_names, probs)}
title = 'Cat and Dog classification'
description = 'An EfficientNetB0 feature extractor computer vision model to classify the cats and dogs'


def _example_paths(directory='examples'):
    """Return ``[[path], ...]`` rows for the Gradio examples gallery.

    Entries are sorted so the gallery order is deterministic. Hidden files
    (e.g. ``.gitkeep``) and sub-directories are skipped because Gradio cannot
    load them as images. A missing directory yields an empty list, so the app
    still starts without examples.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    ]


example_list = _example_paths()

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    # Show every class the model can predict, tracking class_names.
    outputs=gr.Label(num_top_classes=len(class_names), label='Predictions'),
    title=title,
    # None (Gradio's default) disables the gallery when no examples exist.
    examples=example_list or None,
    description=description,
)
demo.launch(share=True)