File size: 1,923 Bytes
a31153b
 
 
 
 
 
b6e309c
 
 
 
 
53e6e66
51668c8
53e6e66
 
51668c8
 
 
4ad261a
a31153b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e309c
 
 
51668c8
b6e309c
51668c8
 
b6e309c
 
 
a31153b
 
 
 
 
 
 
 
 
 
b6e309c
a31153b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from transformers import pipeline

# Translation table from the generic "LABEL_<index>" names to human-readable
# class names. The order below assumes the training dataset's class folders
# were sorted alphabetically (bumblebee, honeybee, vespidae), which gives
#   0 => bumblebee, 1 => honeybee, 2 => vespidae.
# If your model was trained with a different class order, reorder the names
# here. Printing `test_dataset.classes` in your training script shows the
# actual index order.

label_map = {
    f"LABEL_{index}": class_name
    for index, class_name in enumerate(("bumblebee", "honeybee", "vespidae"))
}

# Identifier of the model handed to the transformers pipeline below.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"

# Created once at module import so every request reuses the same pipeline.
classifier = pipeline(task="image-classification", model=model_id)

def classify_image_from_url(image_url: str):
    """
    Download an image from a public URL and classify it.

    The image is fetched over HTTP(S), decoded to RGB, and passed to the
    module-level ``classifier`` pipeline. Raw ``LABEL_x`` names are translated
    through ``label_map`` (unknown labels are kept as-is) and scores are
    rounded to 8 decimal places.

    Args:
        image_url: Public ``http://`` or ``https://`` URL of the image.
            Leading and trailing whitespace is ignored.

    Returns:
        On success, a list of prediction dicts, each carrying at least
        ``"label"`` and ``"score"``. On failure (missing or unsupported URL,
        network error, undecodable image, inference error), a dict of the
        form ``{"error": "<message>"}``. Ordinary exceptions are not
        propagated, so the UI always has something to display.
    """
    # Reject missing input up front with a readable message instead of
    # surfacing a low-level error from the HTTP library.
    if not isinstance(image_url, str) or not image_url.strip():
        return {"error": "Please provide an image URL."}

    url = image_url.strip()
    if not url.lower().startswith(("http://", "https://")):
        return {"error": "Only http:// and https:// image URLs are supported."}

    # NOTE(review): the URL is user-supplied and fetched server-side. If this
    # app is deployed where internal addresses are reachable, consider
    # restricting the allowed hosts.
    try:
        # Without a timeout an unresponsive host would block this handler
        # indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Run inference
        predictions = classifier(image)

        # Build fresh dicts rather than mutating the pipeline's own output.
        # Scores are rounded to 8 decimal places to trim float noise; very
        # small values may still be rendered in exponent notation when
        # serialized, because they remain floats.
        return [
            {
                **prediction,
                "label": label_map.get(prediction["label"], prediction["label"]),
                "score": float(f"{prediction['score']:.8f}"),
            }
            for prediction in predictions
        ]

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the request.
        return {"error": str(e)}

# Gradio UI: a single-line textbox takes the image URL and the return value of
# classify_image_from_url is rendered as JSON.
demo = gr.Interface(
    title="ResNet-50 Image Classifier",
    description="Enter a public image URL to get top predictions with custom labels.",
    fn=classify_image_from_url,
    inputs=gr.Textbox(label="Image URL", lines=1),
    outputs="json",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()