"""Gradio demo that classifies olive-leaf photos with a fine-tuned MobileOne-S2."""

import gradio as gr
import timm
import torch
from timm.data import create_transform, resolve_data_config
from torch import nn

# Class labels, in the order of the classifier's output units.
# NOTE(review): assumes this matches the label order used at training time.
categories = ("Aculus Olearius", "Healthy", "Peacock Spot")

# timm registers its models under lowercase names; "MobileOne_s2" is not found.
model = timm.create_model("mobileone_s2", pretrained=False)
# Swap the ImageNet head for one output per category so the fine-tuned
# state dict's head weights have matching shapes.
model.head.fc = nn.Linear(model.head.fc.in_features, len(categories))
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(
    torch.load("olive_classifier.pth", map_location="cpu", weights_only=True)
)
model.eval()

# timm's default eval pipeline for this architecture: resize, center-crop,
# convert to tensor, normalize.
# NOTE(review): assumes training used the same normalization; confirm.
_transform = create_transform(**resolve_data_config({}, model=model))


def classify_health(input_img):
    """Return per-class probabilities for a single leaf image.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        Mapping of category name to softmax probability (the format
        ``gr.Label`` expects), or ``None`` when no image was given.
    """
    if input_img is None:
        return None
    # Force 3 channels (uploads may be RGBA or grayscale), then add a batch dim.
    batch = _transform(input_img.convert("RGB")).unsqueeze(0)
    with torch.inference_mode():
        logits = model(batch)
    # The network emits raw logits; softmax turns them into probabilities.
    probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(categories, map(float, probs)))


labels = gr.Label()
examples = [
    "examples/healthy.jpg",
    "examples/aculus_2.jpg",
    "examples/peacock_3.jpeg",
]

demo = gr.Interface(
    classify_health,
    # type="pil" so the timm transform can resize; height/width only set the
    # on-screen size of the widget, not the tensor size.
    inputs=gr.Image(type="pil", height=224, width=224),
    outputs=labels,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(inline=False)