halictus commited on
Commit
51668c8
·
verified ·
1 Parent(s): d575e3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -4,12 +4,16 @@ from PIL import Image
4
  from io import BytesIO
5
  from transformers import pipeline
6
 
7
- #new
8
- # 1. Load a pretrained ResNet-50 from the Hugging Face Hub
 
 
 
 
 
9
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
10
  classifier = pipeline("image-classification", model=model_id)
11
 
12
- # 2. Define an inference function
13
  def classify_image_from_url(image_url: str):
14
  """
15
  Downloads an image from a public URL and runs it through
@@ -24,15 +28,16 @@ def classify_image_from_url(image_url: str):
24
  # Run inference
25
  results = classifier(image)
26
 
27
- # You can return raw results or format them as desired
 
 
 
 
28
  return results
29
 
30
  except Exception as e:
31
  return {"error": str(e)}
32
 
33
- # 3. Create a Gradio interface
34
- # - We accept a single Textbox input (the public image URL)
35
- # - We return the classification results in JSON format
36
  demo = gr.Interface(
37
  fn=classify_image_from_url,
38
  inputs=gr.Textbox(lines=1, label="Image URL"),
@@ -41,7 +46,5 @@ demo = gr.Interface(
41
  description="Enter a public image URL to get top predictions."
42
  )
43
 
44
- # 4. Launch the app
45
  if __name__ == "__main__":
46
  demo.launch()
47
-
 
4
  from io import BytesIO
5
  from transformers import pipeline
6
 
7
+ # Our label mapping:
8
+ label_map = {
9
+ "LABEL_0": "honeybee",
10
+ "LABEL_1": "bumblebee",
11
+ "LABEL_2": "vespidae"
12
+ }
13
+
14
  model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
15
  classifier = pipeline("image-classification", model=model_id)
16
 
 
17
  def classify_image_from_url(image_url: str):
18
  """
19
  Downloads an image from a public URL and runs it through
 
28
  # Run inference
29
  results = classifier(image)
30
 
31
+ # Post-process to replace "LABEL_0" etc. with "honeybee", "bumblebee", "vespidae"
32
+ for r in results:
33
+ if r["label"] in label_map:
34
+ r["label"] = label_map[r["label"]]
35
+
36
  return results
37
 
38
  except Exception as e:
39
  return {"error": str(e)}
40
 
 
 
 
41
  demo = gr.Interface(
42
  fn=classify_image_from_url,
43
  inputs=gr.Textbox(lines=1, label="Image URL"),
 
46
  description="Enter a public image URL to get top predictions."
47
  )
48
 
 
49
  if __name__ == "__main__":
50
  demo.launch()