halictus's picture
Update app.py
f382800 verified
raw
history blame
1.93 kB
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from transformers import pipeline
import torch
from torchvision import transforms
# Build the image-classification pipeline once, at import time, so every
# request handled by this module reuses the same loaded model.
# Device 0 is the first CUDA GPU; -1 means run on CPU.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
classifier = pipeline(
    "image-classification",
    model=model_id,
    device=(0 if torch.cuda.is_available() else -1),
)
# Training-time preprocessing, reproduced here so inference sees identically
# prepared inputs: shorter side resized to 256, a 224x224 center crop,
# conversion to a tensor, then per-channel normalization with the listed
# mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def classify_image_from_url(image_url: str):
    """
    Download an image from a public URL and classify it.

    The image is preprocessed with the module-level ``preprocess`` transform
    (the same steps as the training script). The resulting tensor is then fed
    directly to the fine-tuned ResNet-50 model held by ``classifier``.

    Args:
        image_url: Publicly reachable URL of the image to classify.

    Returns:
        On success, a list of ``{"label": str, "score": float}`` dicts. The
        list is sorted by descending score, has at most 5 entries, and scores
        are rounded to 8 decimal places. On failure, a dict with a single
        ``"error"`` key describing what went wrong.
    """
    # requests has no default timeout; bound the download so an unresponsive
    # host cannot hang this handler indefinitely.
    request_timeout_s = 10
    # Same number of results the image-classification pipeline returns by default.
    top_k = 5

    try:
        response = requests.get(image_url, timeout=request_timeout_s)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Apply the same preprocessing as in the training script.
        image_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

        # The pipeline's __call__ only accepts PIL images / paths / URLs, not
        # pre-normalized tensors. Passing the tensor there always raised.
        # Run the underlying model on the tensor instead.
        model = classifier.model
        with torch.no_grad():
            logits = model(pixel_values=image_tensor.to(model.device)).logits

        # NOTE(review): softmax assumes a single-label classifier; confirm
        # against the model config's problem_type.
        probs = logits.softmax(dim=-1)[0]
        scores, indices = probs.topk(min(top_k, probs.numel()))
        id2label = model.config.id2label

        # Round scores to 8 decimals to keep the JSON output compact.
        return [
            {"label": id2label[idx], "score": round(score, 8)}
            for score, idx in zip(scores.tolist(), indices.tolist())
        ]
    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to download image: {str(e)}"}
    except Exception as e:
        return {"error": f"An error occurred during classification: {str(e)}"}
# Gradio UI: a single-line URL textbox in, the classifier's JSON result out.
demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=gr.Textbox(label="Image URL", lines=1),
    outputs="json",
    title="ResNet-50 Image Classifier",
    description="Enter public image URL to get top predictions.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()