ryanwang058 commited on
Commit
8d91a27
·
1 Parent(s): a29ef42

Fix subdir loading

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -18,44 +18,49 @@ classifiers = {
18
  for name, model_type, model_path in zip(model_paths.keys(), model_types, model_paths.values())
19
  }
20
 
21
- # List all test images
22
- def get_test_images():
23
- return [f for f in os.listdir(TEST_IMAGE_DIR) if f.lower().endswith(('.jpg', '.png'))]
 
 
 
 
 
 
 
24
 
25
  def predict(image, model_name):
26
  classifier = classifiers[model_name]
27
  predicted_class = classifier.predict(image)
28
  return predicted_class
29
 
30
- def classify_uploaded_image(image, model_name):
31
- return predict(image, model_name)
32
-
33
- def classify_preloaded_image(image_name, model_name):
34
- image_path = os.path.join(TEST_IMAGE_DIR, image_name)
35
  image = Image.open(image_path).convert("RGB")
36
  return predict(image, model_name)
37
 
38
  model_choices = list(model_paths.keys())
39
- test_images = get_test_images()
40
 
41
  # Define Gradio app
42
  with gr.Blocks() as demo:
43
  gr.Markdown("# Plant Disease Classifier")
44
 
45
- with gr.Tab("Upload an Image"):
46
- with gr.Row():
47
- image_input = gr.Image(type="pil", label="Upload an Image")
48
- model_input_upload = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
49
- classify_button_upload = gr.Button("Classify")
50
- output_text_upload = gr.Textbox(label="Predicted Class")
51
- classify_button_upload.click(classify_uploaded_image, inputs=[image_input, model_input_upload], outputs=output_text_upload)
52
-
53
  with gr.Tab("Select a Preloaded Image"):
54
  with gr.Row():
55
- image_dropdown = gr.Dropdown(choices=test_images, label="Select a Test Image")
 
56
  model_input_preloaded = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
 
57
  classify_button_preloaded = gr.Button("Classify")
58
  output_text_preloaded = gr.Textbox(label="Predicted Class")
59
- classify_button_preloaded.click(classify_preloaded_image, inputs=[image_dropdown, model_input_preloaded], outputs=output_text_preloaded)
 
 
 
 
 
 
 
 
60
 
61
  demo.launch()
 
18
  for name, model_type, model_path in zip(model_paths.keys(), model_types, model_paths.values())
19
  }
20
 
21
+ def get_subdirectories(directory):
22
+ """Get a list of subdirectories in the directory."""
23
+ return [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
24
+
25
+ def get_images_in_subdirectory(subdirectory):
26
+ """Get a list of images in the selected subdirectory."""
27
+ subdir_path = os.path.join(TEST_IMAGE_DIR, subdirectory)
28
+ if os.path.exists(subdir_path):
29
+ return [f for f in os.listdir(subdir_path) if f.lower().endswith(('.jpg', '.png'))]
30
+ return []
31
 
32
  def predict(image, model_name):
33
  classifier = classifiers[model_name]
34
  predicted_class = classifier.predict(image)
35
  return predicted_class
36
 
37
+ def classify_preloaded_image(subdirectory, image_name, model_name):
38
+ image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
 
 
 
39
  image = Image.open(image_path).convert("RGB")
40
  return predict(image, model_name)
41
 
42
  model_choices = list(model_paths.keys())
 
43
 
44
  # Define Gradio app
45
  with gr.Blocks() as demo:
46
  gr.Markdown("# Plant Disease Classifier")
47
 
 
 
 
 
 
 
 
 
48
  with gr.Tab("Select a Preloaded Image"):
49
  with gr.Row():
50
+ subdir_dropdown = gr.Dropdown(choices=get_subdirectories(TEST_IMAGE_DIR), label="Select a Subdirectory")
51
+ image_dropdown = gr.Dropdown(choices=[], label="Select an Image")
52
  model_input_preloaded = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
53
+
54
  classify_button_preloaded = gr.Button("Classify")
55
  output_text_preloaded = gr.Textbox(label="Predicted Class")
56
+
57
+ # Update image dropdown based on selected subdirectory
58
+ def update_images(subdirectory):
59
+ return gr.update(choices=get_images_in_subdirectory(subdirectory))
60
+
61
+ subdir_dropdown.change(update_images, inputs=subdir_dropdown, outputs=image_dropdown)
62
+ classify_button_preloaded.click(
63
+ classify_preloaded_image, inputs=[subdir_dropdown, image_dropdown, model_input_preloaded], outputs=output_text_preloaded
64
+ )
65
 
66
  demo.launch()