0-ma commited on
Commit
dde0d91
·
verified ·
1 Parent(s): 657b87f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -30,6 +30,8 @@ feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name)
30
  classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
31
 
32
  def predict(image, selected_model):
 
 
33
  feature_extractor = feature_extractors[selected_model]
34
  model = classification_models[selected_model]
35
 
@@ -53,19 +55,26 @@ with gr.Blocks() as demo:
53
  gr.Markdown(description)
54
 
55
  with gr.Row():
56
- model_dropdown = gr.Dropdown(choices=model_names, label="Selected Model", value=model_names[0])
57
- gr.Examples(
58
- examples=example_images,
59
- inputs=image_input,
60
- outputs=output,
61
- fn=lambda img: predict(img, model_dropdown.value),
62
- cache_examples=True,
63
- )
64
- image_input = gr.Image(type="pil")
65
 
66
- output = gr.Label()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- image_input.change(fn=lambda img: predict(img, model_dropdown.value), inputs=[image_input], outputs=output)
69
- model_dropdown.change(fn=lambda img, model: predict(img, model), inputs=[image_input, model_dropdown], outputs=output)
70
 
71
  demo.launch()
 
30
  classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
31
 
32
  def predict(image, selected_model):
33
+ if image is None:
34
+ return None
35
  feature_extractor = feature_extractors[selected_model]
36
  model = classification_models[selected_model]
37
 
 
55
  gr.Markdown(description)
56
 
57
  with gr.Row():
58
+ model_dropdown = gr.Dropdown(choices=model_names, label="Select Model", value=model_names[0])
59
+ image_input = gr.Image(type="pil")
 
 
 
 
 
 
 
60
 
61
+ # Move the Examples section here, before the output
62
+ gr.Examples(
63
+ examples=example_images,
64
+ inputs=image_input,
65
+ label="Click on an example image to test",
66
+ )
67
+
68
+ # Output section
69
+ output = gr.Label(label="Classification Result")
70
+
71
+ # Event handlers
72
+ def classify(img, model):
73
+ if img is not None:
74
+ return predict(img, model)
75
+ return None
76
 
77
+ image_input.change(fn=classify, inputs=[image_input, model_dropdown], outputs=output)
78
+ model_dropdown.change(fn=classify, inputs=[image_input, model_dropdown], outputs=output)
79
 
80
  demo.launch()