File size: 1,929 Bytes
a31153b
 
 
 
 
d448e99
f382800
a31153b
d448e99
e3f40a0
d448e99
a31153b
f382800
 
 
 
 
 
 
 
a31153b
 
f382800
d448e99
a31153b
 
f382800
a31153b
 
 
9c858bd
f382800
 
 
a31153b
f382800
b6e309c
9c858bd
51668c8
d448e99
b6e309c
a31153b
 
d448e99
 
a31153b
d448e99
a31153b
 
 
 
 
 
2466430
a31153b
 
 
d448e99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from transformers import pipeline
import torch
from torchvision import transforms

# Build the classifier once at import time so every request reuses the same
# loaded model. Device 0 is the first CUDA GPU; -1 tells the pipeline to use CPU.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
_pipeline_device = 0 if torch.cuda.is_available() else -1
classifier = pipeline(
    task="image-classification",
    model=model_id,
    device=_pipeline_device,
)

# Preprocessing intended to match the training script: resize the shorter
# side to 256 px, take a centred 224x224 crop, convert to a float tensor,
# then normalise each RGB channel with the (ImageNet) mean/std below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

def classify_image_from_url(image_url: str, *, top_k: int = 5, timeout: float = 10.0):
    """
    Download an image from a public URL and classify it with the fine-tuned ResNet-50.

    The image goes through the module-level ``preprocess`` transform (the
    preprocessing the original code describes as matching the training
    script). The resulting tensor is fed to the pipeline's underlying model.
    The Hugging Face ``image-classification`` pipeline itself only accepts
    PIL images, URLs, local paths or base64 strings. Handing it a tensor
    raises inside the pipeline, which previously made every call return an
    error dict.

    Args:
        image_url: Publicly reachable URL of the image.
        top_k: Maximum number of predictions to return. It is capped at the
            number of labels. The default of 5 matches the pipeline default.
        timeout: Seconds to wait on the HTTP download before giving up.

    Returns:
        On success, a list of ``{"label": str, "score": float}`` dicts in
        descending score order, with scores rounded to 8 decimal places.
        On any failure, a dict ``{"error": <message>}``. Failures are
        reported, not raised, so the Gradio JSON output can display them.
    """
    try:
        # Without a timeout, an unresponsive host would block this worker indefinitely.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Training-time preprocessing plus a batch dimension, placed on the
        # same device the pipeline loaded the model onto.
        image_tensor = preprocess(image).unsqueeze(0).to(classifier.device)

        model = classifier.model
        with torch.no_grad():
            logits = model(pixel_values=image_tensor).logits[0]

        # Mirror the pipeline's post-processing: sigmoid for multi-label or
        # single-logit heads, softmax across classes otherwise.
        config = model.config
        if config.problem_type == "multi_label_classification" or logits.numel() == 1:
            probs = torch.sigmoid(logits)
        else:
            probs = torch.softmax(logits, dim=-1)

        k = min(top_k, probs.numel())
        scores, indices = torch.topk(probs, k)

        # Round to 8 decimal places; scores below 5e-9 collapse to 0.0.
        return [
            {"label": config.id2label[idx], "score": float(f"{score:.8f}")}
            for score, idx in zip(scores.tolist(), indices.tolist())
        ]

    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to download image: {str(e)}"}
    except Exception as e:
        return {"error": f"An error occurred during classification: {str(e)}"}

# Gradio UI: one single-line URL textbox in, raw JSON predictions out.
_url_box = gr.Textbox(label="Image URL", lines=1)

demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=_url_box,
    outputs="json",
    description="Enter public image URL to get top predictions.",
    title="ResNet-50 Image Classifier",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()