0-ma commited on
Commit
185a1c6
·
verified ·
1 Parent(s): 901ff0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -1,37 +1,81 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import requests
6
- labels = [
 
7
  "None",
8
  "Circle",
9
  "Triangle",
10
  "Square",
11
  "Pentagon",
12
  "Hexagon"
13
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- feature_extractor = AutoImageProcessor.from_pretrained('0-ma/vit-geometric-shapes-tiny')
16
- model = AutoModelForImageClassification.from_pretrained('0-ma/vit-geometric-shapes-tiny')
17
 
18
 
19
- def predict(image):
 
 
 
 
 
 
 
 
20
  inputs = feature_extractor(images=[image], return_tensors="pt")
21
  logits = model(**inputs)['logits'].cpu().detach().numpy()[0]
22
  logits_positive = logits
23
  logits_positive[logits < 0] = 0
24
  logits_positive = logits_positive/np.sum(logits_positive)
 
25
  confidences = {}
26
  for i in range(len(labels)):
27
- if logits[i]>0:
28
  confidences[labels[i]] = float(logits_positive[i])
29
  return confidences
30
 
31
-
32
  title = "Geometric Shape Classifier"
33
- description = "The geometric shape classifier: 0-ma/vit-geometric-shapes-tiny."
34
- examples = ['example/1_None.jpg','example/2_Circle.jpg','example/3_Triangle.jpg','example/4_Square.jpg','example/5_Pentagone.jpg','example/6_Hexagone.jpg']
35
 
 
 
 
 
 
 
 
 
36
 
37
- gr.Interface(fn=predict,inputs=gr.Image(type="pil"),outputs=gr.Label(),title=title,description=description,examples=examples).launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import requests
6
+
7
+ labels = [
8
  "None",
9
  "Circle",
10
  "Triangle",
11
  "Square",
12
  "Pentagon",
13
  "Hexagon"
14
+ ]
15
+
16
+ # Available models for the dropdown
17
+ models = {
18
+ "0-ma/swin-geometric-shapes-tiny": "0-ma/swin-geometric-shapes-tiny",
19
+ "0-ma/mobilenet-v2-geometric-shapes": "0-ma/mobilenet-v2-geometric-shapes",
20
+ "0-ma/focalnet-geometric-shapes-tiny": "0-ma/focalnet-geometric-shapes-tiny" ,
21
+ "0-ma/efficientnet-b2-geometric-shapes":"0-ma/efficientnet-b2-geometric-shapes",
22
+ "0-ma/beit-geometric-shapes-base":"0-ma/beit-geometric-shapes-base",
23
+ "0-ma/mit-b0-geometric-shapes":"0-ma/mit-b0-geometric-shapes",
24
+ "0-ma/vit-geometric-shapes-base":"0-ma/vit-geometric-shapes-base",
25
+ "0-ma/resnet-geometric-shapes":"0-ma/resnet-geometric-shapes",
26
+ "0-ma/vit-geometric-shapes-tiny":"0-ma/vit-geometric-shapes-tiny",
27
+
28
+ }
29
+
30
+ # Load the default model
31
+ #feature_extractor = AutoImageProcessor.from_pretrained(models["Tiny Model"])
32
+ #model = AutoModelForImageClassification.from_pretrained(models["Tiny Model"])
33
+
34
 
35
+ feature_extractors = { model_name : AutoImageProcessor.from_pretrained(models[model_name]) for model_name in models}
36
+ classification_models = { model_name : AutoModelForImageClassification.from_pretrained(models[model_name]) for model_name in models}
37
 
38
 
39
+ def predict(image, selected_model):
40
+ # Load the selected model
41
+ # feature_extractor = AutoImageProcessor.from_pretrained(models[selected_model])
42
+ # model = AutoModelForImageClassification.from_pretrained(models[selected_model])
43
+
44
+ feature_extractor = feature_extractors[selected_model]
45
+ model = classification_models[selected_model]
46
+
47
+
48
  inputs = feature_extractor(images=[image], return_tensors="pt")
49
  logits = model(**inputs)['logits'].cpu().detach().numpy()[0]
50
  logits_positive = logits
51
  logits_positive[logits < 0] = 0
52
  logits_positive = logits_positive/np.sum(logits_positive)
53
+
54
  confidences = {}
55
  for i in range(len(labels)):
56
+ if logits[i] > 0:
57
  confidences[labels[i]] = float(logits_positive[i])
58
  return confidences
59
 
60
+
61
  title = "Geometric Shape Classifier"
62
+ description = "Select a model to classify geometric shapes."
 
63
 
64
+ examples = [
65
+ 'example/1_None.jpg',
66
+ 'example/2_Circle.jpg',
67
+ 'example/3_Triangle.jpg',
68
+ 'example/4_Square.jpg',
69
+ 'example/5_Pentagone.jpg',
70
+ 'example/6_Hexagone.jpg'
71
+ ]
72
 
73
+ # Adding a dropdown for model selection
74
+ gr.Interface(
75
+ fn=predict,
76
+ inputs=[gr.Image(type="pil"), gr.Dropdown(list(models.keys()), label="Select Model")],
77
+ outputs=gr.Label(),
78
+ title=title,
79
+ description=description,
80
+ examples=examples
81
+ ).launch()