"""Gradio demo: classify an image fetched from a URL as bumblebee, honeybee or vespidae.

A ResNet-50 with a classification head sized to ``class_names`` is rebuilt,
its fine-tuned weights are downloaded at import time, and a Gradio interface
exposes ``classify_image_from_url``.
"""

import logging
from io import BytesIO

import gradio as gr
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

logger = logging.getLogger(__name__)

# Label for each output index of the classifier head.
# NOTE(review): order must match the label indices used in training — confirm.
class_names = ["bumblebee", "honeybee", "vespidae"]

# Seconds to wait when fetching a user-supplied URL. Without a timeout an
# unresponsive host would block the request handler indefinitely.
REQUEST_TIMEOUT = 10

# 1) Create the same resnet50 architecture with the correct number of classes.
# The head size is derived from class_names so the two cannot drift apart.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

# 2) Load the state dict (uploaded to the HF repo).
# NOTE(review): this unpickles a remotely hosted checkpoint; consider passing
# weights_only=True if the installed torch version supports it.
checkpoint_url = "https://huggingface.co/Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50/resolve/main/resnet50_best.pth"
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# 3) Define transforms that match the training code (ImageNet mean/std normalization).
inference_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def classify_image_from_url(image_url: str):
    """Download an image and return it with per-class confidences.

    Args:
        image_url: URL of the image to classify. Surrounding whitespace is ignored.

    Returns:
        A ``(image, rows)`` tuple for the two Gradio outputs:

        - ``image``: the downloaded image as an RGB PIL image, or ``None`` on failure.
        - ``rows``: ``[label, confidence]`` rows sorted by descending confidence.
          Each confidence is a string formatted to 4 decimal places.

        On any failure, ``rows`` is a single ``["Error", message]`` row.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        return None, [["Error", "Please enter an image URL."]]

    try:
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        pil_image = Image.open(BytesIO(response.content)).convert("RGB")

        # Add a batch dimension: the model expects (N, C, H, W).
        input_tensor = inference_transforms(pil_image).unsqueeze(0)

        with torch.no_grad():
            logits = model(input_tensor)
        # squeeze(0) drops only the batch dimension, so probs stays 1-D.
        probs = F.softmax(logits, dim=1).squeeze(0).numpy()
    except Exception as e:  # UI boundary: report the failure in the output table.
        logger.exception("Failed to classify image from %s", image_url)
        return None, [["Error", str(e)]]

    # Sort by confidence, descending.
    sorted_indices = probs.argsort()[::-1]
    results = [[class_names[i], f"{probs[i]:.4f}"] for i in sorted_indices]
    return pil_image, results


demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=gr.Textbox(lines=1, label="Image URL"),
    outputs=[
        gr.Image(label="Input Image"),
        gr.Dataframe(headers=["Label", "Confidence"]),
    ],
    title="ResNet-50 Image Classifier",
    description="Enter public image URL to get top predictions.",
)

if __name__ == "__main__":
    demo.launch()