halictus commited on
Commit
b6e309c
·
verified ·
1 Parent(s): 51668c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -4,7 +4,11 @@ from PIL import Image
4
  from io import BytesIO
5
  from transformers import pipeline
6
 
7
- # Our label mapping:
 
 
 
 
8
  label_map = {
9
  "LABEL_0": "honeybee",
10
  "LABEL_1": "bumblebee",
@@ -27,12 +31,16 @@ def classify_image_from_url(image_url: str):
27
 
28
  # Run inference
29
  results = classifier(image)
30
-
31
- # Post-process to replace "LABEL_0" etc. with "honeybee", "bumblebee", "vespidae"
 
32
  for r in results:
 
33
  if r["label"] in label_map:
34
  r["label"] = label_map[r["label"]]
35
-
 
 
36
  return results
37
 
38
  except Exception as e:
@@ -43,7 +51,7 @@ demo = gr.Interface(
43
  inputs=gr.Textbox(lines=1, label="Image URL"),
44
  outputs="json",
45
  title="ResNet-50 Image Classifier",
46
- description="Enter a public image URL to get top predictions."
47
  )
48
 
49
  if __name__ == "__main__":
 
4
  from io import BytesIO
5
  from transformers import pipeline
6
 
7
+ # Adjust these if your model's order is actually different.
8
+ # For example, if your dataset folders are named (alphabetically):
9
+ # bumblebee, honeybee, vespidae,
10
+ # then 0 => bumblebee, 1 => honeybee, 2 => vespidae (the default PyTorch order).
11
+ # Verify your label indices by printing `test_dataset.classes` in your training script.
12
  label_map = {
13
  "LABEL_0": "honeybee",
14
  "LABEL_1": "bumblebee",
 
31
 
32
  # Run inference
33
  results = classifier(image)
34
+
35
+ # 1) Post-process labels
36
+ # 2) Format scores to remove scientific notation
37
  for r in results:
38
+ # Map from "LABEL_x" to your real class name
39
  if r["label"] in label_map:
40
  r["label"] = label_map[r["label"]]
41
+ # Format score with, e.g., 8 decimal places to avoid scientific notation
42
+ r["score"] = float(f"{r['score']:.8f}")
43
+
44
  return results
45
 
46
  except Exception as e:
 
51
  inputs=gr.Textbox(lines=1, label="Image URL"),
52
  outputs="json",
53
  title="ResNet-50 Image Classifier",
54
+ description="Enter a public image URL to get top predictions with custom labels."
55
  )
56
 
57
  if __name__ == "__main__":