"""Gradio demo that classifies olive-leaf health with a fine-tuned MobileOne-S2."""

import gradio as gr
import timm
import torch
from torchvision import transforms

# Output classes, in the index order of the classifier head's logits.
categories = ("Aculus Olearius", "Healthy", "Peacock Spot")

# Build the architecture without downloading pretrained weights; the
# fine-tuned weights are loaded from disk below.
model = timm.create_model("mobileone_s2", pretrained=False)
# Replace the classification head so it emits one logit per category.
# Derived from `categories` so the head width and the label list cannot drift.
model.head.fc = torch.nn.Linear(model.head.fc.in_features, len(categories))

# Preprocessing built by timm from the model's pretrained config
# (presumably resize / crop / normalize for evaluation — confirm against the
# pinned timm version).
data_transforms = transforms.Compose(
    timm.data.create_transform(
        **timm.data.resolve_data_config(model.pretrained_cfg)
    ).transforms
)

model.load_state_dict(
    torch.load(
        "olive-classifier.pth",
        map_location=torch.device("cpu"),
        weights_only=True,  # only tensors are unpickled, not arbitrary objects
    )
)
model.eval()


def classify_health(input_img):
    """Return per-class probabilities for a single leaf image.

    Args:
        input_img: The image delivered by ``gr.Image`` (a NumPy array under
            Gradio's default ``type="numpy"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``{category: probability}`` dict, as expected by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Without this guard ToTensor fails with an opaque TypeError.
        raise gr.Error("Please upload an image to classify.")
    # NOTE(review): data_transforms receives a tensor here, which relies on
    # timm's to-tensor step tolerating tensor input — confirm with the pinned
    # timm version.
    tensor = transforms.ToTensor()(input_img)
    with torch.no_grad():
        batch = data_transforms(tensor).unsqueeze(0)  # add batch dimension
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
    # .tolist() yields plain Python floats for the single image in the batch.
    return dict(zip(categories, probs[0].tolist()))


labels = gr.Label()

examples = [
    "examples/healthy.jpg",
    "examples/aculus_2.jpg",
    "examples/peacock_3.jpg",
]

demo = gr.Interface(
    classify_health,
    inputs=gr.Image(height=224, width=224),
    outputs=labels,
    examples=examples,
)

demo.launch(inline=False)