from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import pipeline

# Load the model once at import time so every request reuses the same pipeline.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
classifier = pipeline(
    "image-classification",
    model=model_id,
    device=0 if torch.cuda.is_available() else -1,
)

# Seconds to wait on the remote server before abandoning an image download,
# so a slow or unresponsive host cannot block a worker forever.
REQUEST_TIMEOUT = 10

# Preprocessing steps from the training script (resize 256 -> center-crop 224
# -> tensor -> ImageNet mean/std normalization).
# NOTE: classify_image_from_url deliberately does NOT apply this. The
# transformers image-classification pipeline only accepts images (PIL image,
# URL, or path) and runs the model repo's own image processor on them; handing
# it this normalized tensor makes the pipeline raise. Kept for reference and
# for any external code that imports it.
# TODO(review): confirm the model repo's preprocessor config matches these steps.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _download_image(image_url: str) -> Image.Image:
    """Fetch *image_url* and decode the response body as an RGB PIL image.

    Raises:
        requests.exceptions.RequestException: network failure, timeout, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError / OSError: the body is not a decodable image.
    """
    # NOTE(review): this fetches arbitrary user-supplied URLs server-side with
    # no scheme/host allow-list or response-size cap; consider restricting
    # both if the app is exposed publicly.
    response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as img:
        # convert() returns a new, fully loaded image, so it stays valid after
        # the source image is closed by the context manager.
        return img.convert("RGB")


def classify_image_from_url(image_url: str):
    """Download an image from a public URL and classify it.

    The decoded PIL image is passed straight to the ResNet-50 fine-tuned
    image-classification pipeline, which applies the model's own image
    processor before inference.

    Args:
        image_url: Public URL of the image; surrounding whitespace is ignored.

    Returns:
        On success, the pipeline's list of ``{"label": ..., "score": ...}``
        predictions with scores rounded to 8 decimal places. On failure, a
        dict of the form ``{"error": "<message>"}``; errors are returned
        rather than raised so the Gradio JSON output can display them.
    """
    url = (image_url or "").strip()
    if not url:
        return {"error": "Please enter an image URL."}

    try:
        image = _download_image(url)
        # The pipeline expects an image, not a pre-normalized tensor.
        results = classifier(image)
    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to download image: {str(e)}"}
    except Exception as e:  # UI boundary: report any failure instead of crashing
        return {"error": f"An error occurred during classification: {str(e)}"}

    # Round scores to 8 decimal places to keep the JSON output compact.
    for r in results:
        r["score"] = float(f"{r['score']:.8f}")
    return results


demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=gr.Textbox(lines=1, label="Image URL"),
    outputs="json",
    title="ResNet-50 Image Classifier",
    description="Enter public image URL to get top predictions.",
)

if __name__ == "__main__":
    demo.launch()