0-ma commited on
Commit
2559312
·
verified ·
1 Parent(s): 40fea1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  from PIL import Image
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import requests
6
- selected_model = "0-ma/vit-geometric-shapes-tiny"
7
 
8
  model_names = [
9
  "0-ma/swin-geometric-shapes-tiny",
@@ -27,23 +26,13 @@ examples = [
27
  ]
28
 
29
  labels = [example.split("_")[1].split(".")[0] for example in examples]
30
- # Load the default model
31
- #feature_extractor = AutoImageProcessor.from_pretrained(models["Tiny Model"])
32
- #model = AutoModelForImageClassification.from_pretrained(models["Tiny Model"])
33
 
 
 
34
 
35
- feature_extractors = { model_name : AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
36
- classification_models = { model_name : AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
37
-
38
-
39
- def predict(image):
40
- # Load the selected model
41
- # feature_extractor = AutoImageProcessor.from_pretrained(models[selected_model])
42
- # model = AutoModelForImageClassification.from_pretrained(models[selected_model])
43
-
44
  feature_extractor = feature_extractors[selected_model]
45
  model = classification_models[selected_model]
46
-
47
 
48
  inputs = feature_extractor(images=[image], return_tensors="pt")
49
  logits = model(**inputs)['logits'].cpu().detach().numpy()[0]
@@ -57,17 +46,20 @@ def predict(image):
57
  confidences[labels[i]] = float(logits_positive[i])
58
  return confidences
59
 
60
-
61
  title = "Geometric Shape Classifier"
62
  description = "Select a model to classify geometric shapes."
63
 
64
-
65
- # Adding a dropdown for model selection
66
- gr.Interface(
67
  fn=predict,
68
- inputs=gr.Image(type="pil"),
 
 
 
69
  outputs=gr.Label(),
70
  title=title,
71
  description=description,
72
  examples=examples
73
- ).launch()
 
 
 
3
  from PIL import Image
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import requests
 
6
 
7
  model_names = [
8
  "0-ma/swin-geometric-shapes-tiny",
 
26
  ]
27
 
28
  labels = [example.split("_")[1].split(".")[0] for example in examples]
 
 
 
29
 
30
+ feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
31
+ classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
32
 
33
+ def predict(image, selected_model):
 
 
 
 
 
 
 
 
34
  feature_extractor = feature_extractors[selected_model]
35
  model = classification_models[selected_model]
 
36
 
37
  inputs = feature_extractor(images=[image], return_tensors="pt")
38
  logits = model(**inputs)['logits'].cpu().detach().numpy()[0]
 
46
  confidences[labels[i]] = float(logits_positive[i])
47
  return confidences
48
 
 
49
  title = "Geometric Shape Classifier"
50
  description = "Select a model to classify geometric shapes."
51
 
52
+ # Create the Gradio interface
53
+ iface = gr.Interface(
 
54
  fn=predict,
55
+ inputs=[
56
+ gr.Image(type="pil"),
57
+ gr.Dropdown(choices=model_names, label="Select Model", value=model_names[0])
58
+ ],
59
  outputs=gr.Label(),
60
  title=title,
61
  description=description,
62
  examples=examples
63
+ )
64
+
65
+ # Launch the interf