halictus commited on
Commit
f382800
·
verified ·
1 Parent(s): d448e99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -4,23 +4,36 @@ from PIL import Image
4
  from io import BytesIO
5
  from transformers import pipeline
6
  import torch
 
7
 
8
  # Cache the model loading
9
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
10
  classifier = pipeline("image-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
11
 
 
 
 
 
 
 
 
 
12
  def classify_image_from_url(image_url: str):
13
  """
14
- Downloads an image from a public URL and runs it through
15
  the ResNet-50 fine-tuned image-classification pipeline, returning the top predictions.
16
  """
17
  try:
 
18
  response = requests.get(image_url)
19
  response.raise_for_status()
20
  image = Image.open(BytesIO(response.content)).convert("RGB")
21
 
 
 
 
22
  # Run inference
23
- results = classifier(image)
24
 
25
  # Format scores to remove scientific notation
26
  for r in results:
 
4
  from io import BytesIO
5
  from transformers import pipeline
6
  import torch
7
+ from torchvision import transforms
8
 
9
  # Cache the model loading
10
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
11
  classifier = pipeline("image-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
12
 
13
+ # Define the same preprocessing steps as in the training script
14
+ preprocess = transforms.Compose([
15
+ transforms.Resize(256),
16
+ transforms.CenterCrop(224),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
19
+ ])
20
+
21
  def classify_image_from_url(image_url: str):
22
  """
23
+ Downloads an image from a public URL, preprocesses it, and runs it through
24
  the ResNet-50 fine-tuned image-classification pipeline, returning the top predictions.
25
  """
26
  try:
27
+ # Download the image
28
  response = requests.get(image_url)
29
  response.raise_for_status()
30
  image = Image.open(BytesIO(response.content)).convert("RGB")
31
 
32
+ # Apply the same preprocessing as in the training script
33
+ image_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
34
+
35
  # Run inference
36
+ results = classifier(image_tensor)
37
 
38
  # Format scores to remove scientific notation
39
  for r in results: