# --- Text captured from the Hugging Face Spaces file viewer (not code) ---
# Spaces: Sleeping
# File size: 1,929 Bytes
# Per-line commit hashes: a31153b d448e99 f382800 a31153b d448e99 e3f40a0 d448e99 a31153b f382800 a31153b f382800 d448e99 a31153b f382800 a31153b 9c858bd f382800 a31153b f382800 b6e309c 9c858bd 51668c8 d448e99 b6e309c a31153b d448e99 a31153b d448e99 a31153b 2466430 a31153b d448e99
# Viewer line numbers: 1-58
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from transformers import pipeline
import torch
from torchvision import transforms
# Build the image-classification pipeline once, at import time, so every
# request handled by this process reuses the same loaded model.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
classifier = pipeline(
    task="image-classification",
    model=model_id,
    # Pipeline device convention: 0 -> first CUDA GPU, -1 -> CPU.
    device=(0 if torch.cuda.is_available() else -1),
)
# Preprocessing that mirrors the training script: resize the shorter side to
# 256 px, center-crop to 224x224, convert the PIL image to a float tensor,
# then normalize each channel with the standard ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def classify_image_from_url(image_url: str, *, top_k: int = 5, timeout: float = 10.0):
    """Download an image from a public URL and classify it.

    The image is preprocessed with the module-level ``preprocess`` transform
    (the same steps as the training script) and the resulting batch is fed
    directly to the model wrapped by ``classifier``. The HF image-classification
    pipeline itself only accepts PIL images / URLs / paths, not tensors, so we
    bypass its own preprocessing and replicate its postprocessing here.

    Args:
        image_url: Public HTTP(S) URL of the image. Surrounding whitespace is ignored.
        top_k: Maximum number of predictions to return (5 matches the
            pipeline's default). Clipped to the number of model labels.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts sorted by descending
        score, or ``{"error": str}`` if the URL is empty, the download fails,
        or classification fails.
    """
    url = (image_url or "").strip()
    if not url:
        return {"error": "Failed to download image: no image URL provided"}

    # Download the image; without a timeout, requests can block indefinitely.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to download image: {e}"}

    try:
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Apply the same preprocessing as in the training script and add a
        # batch dimension; move it to whatever device the pipeline uses.
        pixel_values = preprocess(image).unsqueeze(0).to(classifier.device)

        model = classifier.model
        with torch.no_grad():
            logits = model(pixel_values=pixel_values).logits[0]

        # Mirror the pipeline's default scoring: sigmoid for multi-label or
        # single-logit models, softmax otherwise.
        config = model.config
        if getattr(config, "problem_type", None) == "multi_label_classification" or config.num_labels == 1:
            probs = logits.sigmoid()
        else:
            probs = logits.softmax(dim=-1)

        k = max(0, min(top_k, probs.numel()))
        scores, indices = probs.topk(k)
        id2label = config.id2label

        # Round to 8 decimal places to keep the JSON compact. Note this does not
        # by itself prevent very small values from rendering as e.g. 1e-08.
        return [
            {"label": id2label.get(idx, f"LABEL_{idx}"), "score": round(score, 8)}
            for score, idx in zip(scores.tolist(), indices.tolist())
        ]
    except Exception as e:
        # Top-level boundary for the UI: report any decode/inference failure
        # to the user instead of crashing the request.
        return {"error": f"An error occurred during classification: {e}"}
# Gradio UI: a single URL textbox in, the JSON returned by
# classify_image_from_url out (a prediction list or an {"error": ...} dict).
demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=gr.Textbox(lines=1, label="Image URL"),
    outputs="json",
    title="ResNet-50 Image Classifier",
    description="Enter a public image URL to get the top predictions.",
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()