# Garment Classifier — Gradio demo for a Fashion-MNIST model.
from PIL import Image
import gradio as gr
import torch
import torchvision.transforms as transforms
# Star import so the model's class definition is importable when
# torch.load() unpickles the checkpoint below.
from model import *
# UI metadata and model setup for the Gradio demo.
title = "Garment Classifier"
description = "Trained on the Fashion MNIST dataset (28x28 pixels). The model expects images containing only one garment article as in the examples."
inputs = gr.components.Image()
outputs = gr.components.Label()
# Directory of sample images picked up by gr.Interface as clickable examples.
examples = "examples"
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. map_location forces CPU so the demo runs without a GPU.
model = torch.load("model/fashion.mnist.base.pt", map_location=torch.device("cpu"))
# Images need to be transformed to the `Fashion MNIST` dataset format
# see https://arxiv.org/abs/1708.07747
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),  # single channel: tensor shape becomes (1, 28, 28)
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # maps [0, 1] -> [-1, 1]
        # Invert colors — presumably to match Fashion-MNIST's light-garment-on-
        # dark-background convention (TODO confirm against training pipeline).
        # NOTE(review): inverting AFTER Normalize shifts the range to [0, 2],
        # not [-1, 1]; verify this matches how the model was trained.
        transforms.Lambda(lambda x: 1.0 - x),
        # NOTE(review): the next two lambdas are an identity round-trip for a
        # (1, 28, 28) tensor — (1,28,28) -> (28,28) -> (1,28,28). Kept for
        # byte-compatibility with the trained pipeline.
        transforms.Lambda(lambda x: x[0]),
        transforms.Lambda(lambda x: x.unsqueeze(0)),
    ]
)
def predict(img):
    """Classify a single garment image.

    Args:
        img: numpy array delivered by the Gradio Image component.

    Returns:
        The output of ``model.predictions`` on the preprocessed tensor —
        presumably a label->confidence mapping consumable by gr.Label
        (verify against the model class in ``model.py``).
    """
    tensor = transform(Image.fromarray(img))
    return model.predictions(tensor)
# Assemble the UI: a single tab wrapping a standard gr.Interface.
with gr.Blocks() as demo:
    with gr.Tab("Garment Prediction"):
        gr.Interface(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,  # fix: `title` was defined at module level but never used
            description=description,
        ).queue(default_concurrency_limit=5)  # cap concurrent predictions at 5
demo.launch()