halictus commited on
Commit
d448e99
·
verified ·
1 Parent(s): 2466430

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -3,14 +3,16 @@ import requests
3
  from PIL import Image
4
  from io import BytesIO
5
  from transformers import pipeline
 
6
 
 
7
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
8
- classifier = pipeline("image-classification", model=model_id)
9
 
10
  def classify_image_from_url(image_url: str):
11
  """
12
  Downloads an image from a public URL and runs it through
13
- the ResNet-50 fine tuned image-classification pipeline, returning the top predictions.
14
  """
15
  try:
16
  response = requests.get(image_url)
@@ -22,12 +24,14 @@ def classify_image_from_url(image_url: str):
22
 
23
  # Format scores to remove scientific notation
24
  for r in results:
25
- r["score"] = float(f"{r['score']:.2f}")
26
 
27
  return results
28
 
 
 
29
  except Exception as e:
30
- return {"error": str(e)}
31
 
32
  demo = gr.Interface(
33
  fn=classify_image_from_url,
@@ -38,4 +42,4 @@ demo = gr.Interface(
38
  )
39
 
40
  if __name__ == "__main__":
41
- demo.launch()
 
3
  from PIL import Image
4
  from io import BytesIO
5
  from transformers import pipeline
6
+ import torch
7
 
8
+ # Cache the model loading
9
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
10
+ classifier = pipeline("image-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
11
 
12
  def classify_image_from_url(image_url: str):
13
  """
14
  Downloads an image from a public URL and runs it through
15
+ the ResNet-50 fine-tuned image-classification pipeline, returning the top predictions.
16
  """
17
  try:
18
  response = requests.get(image_url)
 
24
 
25
  # Format scores to remove scientific notation
26
  for r in results:
27
+ r["score"] = float(f"{r['score']:.8f}")
28
 
29
  return results
30
 
31
+ except requests.exceptions.RequestException as e:
32
+ return {"error": f"Failed to download image: {str(e)}"}
33
  except Exception as e:
34
+ return {"error": f"An error occurred during classification: {str(e)}"}
35
 
36
  demo = gr.Interface(
37
  fn=classify_image_from_url,
 
42
  )
43
 
44
  if __name__ == "__main__":
45
+ demo.launch()