# Author: mmek
# Commit: b2e7f62 ("Initial Commit")
# Size: 723 bytes
import gradio as gr
import timm
import torch
from torch import nn
# Class labels, in the order of the classifier's output units. This order
# must match the label order used when olive_classifier.pth was trained.
categories = ("Aculus Olearius", "Healthy", "Peacock Spot")

# Build the MobileOne-S2 architecture without downloading ImageNet weights;
# the fine-tuned weights come from the local checkpoint below. timm registers
# its architectures under lowercase names ("mobileone_s2").
model = timm.create_model("mobileone_s2", pretrained=False)
# Replace the classification layer so it emits one logit per category; its
# shape must match the head.fc tensors stored in the checkpoint.
model.head.fc = nn.Linear(model.head.fc.in_features, len(categories))
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only host; weights_only=True refuses arbitrary pickled objects.
state_dict = torch.load(
    "olive_classifier.pth", map_location="cpu", weights_only=True
)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()
def classify_health(input_img):
    """Classify an olive-leaf image and return per-category probabilities.

    Args:
        input_img: Image delivered by the ``gr.Image`` input. With that
            component's default ``type="numpy"`` this is presumably an
            ``(H, W, C)`` uint8 array (TODO confirm for the pinned Gradio
            version). It is ``None`` when the user submits without an image.

    Returns:
        dict mapping each name in ``categories`` to its softmax probability,
        the format ``gr.Label`` renders. Empty when no image was supplied.
    """
    if input_img is None:
        return {}

    # Input size and normalisation stats timm associates with this model.
    # NOTE(review): assumes the checkpoint was trained with timm's default
    # preprocessing for this architecture -- confirm against training code.
    cfg = timm.data.resolve_data_config({}, model=model)
    _, height, width = cfg["input_size"]

    x = torch.as_tensor(input_img)
    if x.ndim == 2:  # grayscale (H, W) -> (H, W, 1)
        x = x.unsqueeze(-1)
    if x.shape[-1] == 1:  # single channel -> replicate to RGB
        x = x.expand(-1, -1, 3)
    # HWC in [0, 255] -> 1xCxHxW float in [0, 1]; any alpha channel is dropped.
    x = x[..., :3].permute(2, 0, 1).float().div(255.0).unsqueeze(0)
    # The image is not guaranteed to arrive at the model's input size, so
    # resize it explicitly.
    x = torch.nn.functional.interpolate(
        x, size=(height, width), mode="bilinear", align_corners=False, antialias=True
    )
    mean = torch.tensor(cfg["mean"]).view(1, -1, 1, 1)
    std = torch.tensor(cfg["std"]).view(1, -1, 1, 1)
    x = (x - mean) / std

    with torch.no_grad():
        logits = model(x)
    # The model outputs raw logits for a batch of one; softmax over the
    # class dimension turns them into probabilities that sum to 1.
    probs = logits.softmax(dim=1)[0]
    return {name: float(p) for name, p in zip(categories, probs)}
# Output component: renders the {category: probability} dict returned by
# classify_health.
labels = gr.Label()

# Sample images offered beneath the input widget; the paths are relative to
# the directory the app is launched from.
examples = [
    "examples/healthy.jpg",
    "examples/aculus_2.jpg",
    "examples/peacock_3.jpeg",
]

# NOTE(review): in recent Gradio versions height/width size the displayed
# widget only and do not resize the array handed to classify_health --
# confirm for the pinned Gradio version.
demo = gr.Interface(
    fn=classify_health,
    inputs=gr.Image(height=224, width=224),
    outputs=labels,
    examples=examples,
)

demo.launch(inline=False)