Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,23 +4,36 @@ from PIL import Image
|
|
4 |
from io import BytesIO
|
5 |
from transformers import pipeline
|
6 |
import torch
|
|
|
7 |
|
8 |
# Cache the model loading
|
9 |
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
|
10 |
classifier = pipeline("image-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def classify_image_from_url(image_url: str):
|
13 |
"""
|
14 |
-
Downloads an image from a public URL and runs it through
|
15 |
the ResNet-50 fine-tuned image-classification pipeline, returning the top predictions.
|
16 |
"""
|
17 |
try:
|
|
|
18 |
response = requests.get(image_url)
|
19 |
response.raise_for_status()
|
20 |
image = Image.open(BytesIO(response.content)).convert("RGB")
|
21 |
|
|
|
|
|
|
|
22 |
# Run inference
|
23 |
-
results = classifier(
|
24 |
|
25 |
# Format scores to remove scientific notation
|
26 |
for r in results:
|
|
|
4 |
from io import BytesIO
|
5 |
from transformers import pipeline
|
6 |
import torch
|
7 |
+
from torchvision import transforms
|
8 |
|
9 |
# Cache the model loading
|
10 |
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
|
11 |
classifier = pipeline("image-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
|
12 |
|
13 |
+
# Define the same preprocessing steps as in the training script
|
14 |
+
preprocess = transforms.Compose([
|
15 |
+
transforms.Resize(256),
|
16 |
+
transforms.CenterCrop(224),
|
17 |
+
transforms.ToTensor(),
|
18 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
19 |
+
])
|
20 |
+
|
21 |
def classify_image_from_url(image_url: str):
|
22 |
"""
|
23 |
+
Downloads an image from a public URL, preprocesses it, and runs it through
|
24 |
the ResNet-50 fine-tuned image-classification pipeline, returning the top predictions.
|
25 |
"""
|
26 |
try:
|
27 |
+
# Download the image
|
28 |
response = requests.get(image_url)
|
29 |
response.raise_for_status()
|
30 |
image = Image.open(BytesIO(response.content)).convert("RGB")
|
31 |
|
32 |
+
# Apply the same preprocessing as in the training script
|
33 |
+
image_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
|
34 |
+
|
35 |
# Run inference
|
36 |
+
results = classifier(image_tensor)
|
37 |
|
38 |
# Format scores to remove scientific notation
|
39 |
for r in results:
|