# Spaces: Sleeping
import pathlib
import platform

import gradio as gr
from fastai.vision.all import *
# On Linux, alias pathlib.WindowsPath to pathlib.PosixPath before the learner
# is loaded below.
# NOTE(review): presumably the exported model contains pickled WindowsPath
# objects, which cannot be instantiated on Linux - confirm against how
# model.pkl was exported.
if platform.system() == 'Linux':
    pathlib.WindowsPath = pathlib.PosixPath
def get_x(r):
    """Return the ``'name'`` entry of row *r*.

    NOTE(review): nothing in this file calls this function; presumably it
    must exist under this exact name so the pickled learner can resolve it
    when ``load_learner`` runs - confirm before renaming or removing.
    """
    return r['name']
def get_y(r):
    """Return the labels of row *r* as a list of strings.

    The ``'labels'`` entry is split on single spaces, so consecutive spaces
    produce empty-string items (``str.split(' ')`` semantics).

    NOTE(review): nothing in this file calls this function; presumably it
    must exist under this exact name so the pickled learner can resolve it
    when ``load_learner`` runs - confirm before renaming or removing.
    """
    return r['labels'].split(' ')
# Load the exported learner once, at import time, so every request reuses it.
# NOTE: load_learner unpickles the file, so 'model.pkl' must come from a
# trusted source.
learner = load_learner('model.pkl')

# Vocabulary of the learner's data loaders; bla() looks labels up through
# its `o2i` (label -> index) mapping.
labels = learner.dls.vocab
# Each tuple lists the labels that compete within one garment attribute.
# The order of the tuples is the order of the keys in bla()'s result.
_ATTRIBUTE_GROUPS = (
    # shirt length
    ('Crop_length', 'Regular_length', 'Long_length', 'ShirtLength_other'),
    # shirt neck
    ('Round_neck', 'Tailored_collar_neck', 'Turtle_neck', 'V_neck', 'ShirtNeck_other'),
    # shirt sleeve length
    ('Short_sleeve', 'Long_sleeve', 'Sleeveless', 'ShirtSleeveLength_other'),
    # pattern placement
    ('No_pattern', 'Pattern'),
)


def bla(predicted, vocab=None):
    """Pick the highest-scoring label within each attribute group.

    Args:
        predicted: Indexable prediction result; only ``predicted[2]`` is
            read. It must support indexing with a list of integer positions
            and provide ``argmax()`` (assumed to be the per-label score
            tensor returned by ``learner.predict`` - TODO confirm).
        vocab: Object whose ``o2i`` mapping translates a label name into its
            position in ``predicted[2]``. Defaults to the module-level
            ``labels``.

    Returns:
        A dict with one ``{label: score}`` entry per attribute group, in the
        order shirt length, shirt neck, sleeve length, pattern placement.
        Scores are converted to plain ``float``.

    Raises:
        KeyError: If a group label is missing from ``vocab.o2i``.
    """
    if vocab is None:
        vocab = labels
    label_to_index = vocab.o2i
    scores = predicted[2]

    best_per_group = {}
    for group in _ATTRIBUTE_GROUPS:
        # Restrict the scores to this group's labels, in the group's own
        # order, so the winning position indexes straight back into `group`.
        group_scores = scores[[label_to_index[name] for name in group]]
        winner = int(group_scores.argmax())
        best_per_group[group[winner]] = float(group_scores[winner])
    return best_per_group
def predict(img):
    """Classify one image and return the top label of each attribute group.

    Args:
        img: Image input handed over by the Gradio interface; it is passed
            to ``PILImage.create`` unchanged.

    Returns:
        The ``{label: score}`` dict produced by ``bla`` for the learner's
        prediction on the image.
    """
    image = PILImage.create(img)
    return bla(learner.predict(image))
# Text and sample images shown on the demo page.
title = "Multi-Class Classifier"
description = "Fashion multi-class classifier"
article = "<p style='text-align: center'><a href='https://tmabraham.github.io/blog/gradio_hf_spaces_tutorial' target='_blank'>Blog post</a></p>"
examples = ['demo1.jpg', 'demo2.jpg', 'demo3.jpg', 'demo4.jpg', 'demo5.jpg']
interpretation = 'default'
enable_queue = True

# Build the web UI around predict() and start serving it.
# NOTE(review): `gr.inputs` / `gr.outputs`, `interpretation` and
# `enable_queue` look like the legacy Gradio API - confirm the pinned Gradio
# version before upgrading the dependency.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(300, 300)),
    outputs=gr.outputs.Label(),
    title=title,
    description=description,
    article=article,
    examples=examples,
    interpretation=interpretation,
    enable_queue=enable_queue,
).launch()