# Last commit a988558 by ryanwang058: "Allow user to use preloaded images for testing"
# (Header copied from the repository file viewer: raw / history / blame, 2.32 kB.)
import gradio as gr
from PIL import Image
import os
from plant_disease_classifier import PlantDiseaseClassifier
# Directory containing the preloaded test images offered in the second tab.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory at runtime, not against this file's location.
TEST_IMAGE_DIR = "test"

# Model checkpoint files, keyed by the name shown in the "Select Model"
# dropdowns. Each key doubles as the ``model_type`` argument passed to
# PlantDiseaseClassifier.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}
# Kept for backward compatibility. Derived from model_paths so the two can
# never drift out of sync.
model_types = list(model_paths)

# One PlantDiseaseClassifier per configured model, constructed at import time.
# Pairing each name with its own path via .items() replaces a three-way zip
# over parallel sequences, which would silently mis-pair or drop models if one
# sequence were edited without the other.
classifiers = {
    name: PlantDiseaseClassifier(name, path)
    for name, path in model_paths.items()
}
def get_test_images(directory=None):
    """Return the file names of the preloaded test images, in a stable order.

    Args:
        directory: Folder to scan. Defaults to ``TEST_IMAGE_DIR``. It is exposed
            so the listing can be reused or tested against another folder.

    Returns:
        A list of bare file names (not paths) of regular files whose extension
        is ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive). The list is sorted
        case-insensitively so the dropdown order does not depend on the
        filesystem. It is empty if the folder does not exist, so the app can
        still start with only the upload tab usable.
    """
    if directory is None:
        directory = TEST_IMAGE_DIR
    try:
        entries = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    extensions = (".jpg", ".jpeg", ".png")
    images = [
        name
        for name in entries
        if name.lower().endswith(extensions)
        # Skip e.g. a sub-directory that happens to be named "foo.jpg".
        and os.path.isfile(os.path.join(directory, name))
    ]
    # Secondary key breaks ties like "A.jpg" / "a.jpg" deterministically.
    return sorted(images, key=lambda name: (name.lower(), name))
def predict(image, model_name):
    """Run the selected model on *image* and return its prediction.

    Args:
        image: The image to classify. The upload tab supplies a PIL image
            (``gr.Image(type="pil")``); ``None`` means the user clicked
            Classify without providing one.
        model_name: Key into ``classifiers`` (one of the dropdown choices).

    Returns:
        Whatever ``PlantDiseaseClassifier.predict`` returns. It is displayed in
        a Textbox, so it is presumably a class-name string (TODO confirm).

    Raises:
        gr.Error: If no image was provided or ``model_name`` is not a loaded
            model. Raising gr.Error makes the UI show a readable message
            instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")
    classifier = classifiers.get(model_name)
    if classifier is None:
        raise gr.Error(
            f"Unknown model: {model_name!r}. Choose one of: {', '.join(classifiers)}."
        )
    return classifier.predict(image)
def classify_uploaded_image(image, model_name):
    """Classify an image uploaded through the "Upload an Image" tab.

    This is a thin adapter between the Gradio click handler and ``predict``.
    The upload component is configured with ``type="pil"``, so the image is
    forwarded unchanged together with the chosen model name.
    """
    label = predict(image, model_name)
    return label
def classify_preloaded_image(image_name, model_name):
    """Load a bundled test image by file name and classify it.

    Args:
        image_name: Bare file name of an image inside ``TEST_IMAGE_DIR``, as
            chosen in the preloaded-image dropdown. It is ``None`` or empty
            when nothing is selected.
        model_name: Name of the model to use; forwarded to ``predict``.

    Returns:
        The result of ``predict`` for the loaded RGB image.

    Raises:
        gr.Error: If no image is selected, or if the name is not a bare file
            name. The dropdown value can also be sent by API clients, so it
            must not be allowed to point outside ``TEST_IMAGE_DIR``.
    """
    if not image_name:
        raise gr.Error("Please select a test image to classify.")
    # Only bare file names are valid: reject "../x", "sub/x", absolute paths,
    # and the special "." / ".." entries.
    if os.path.basename(image_name) != image_name or image_name in (os.curdir, os.pardir):
        raise gr.Error(f"Invalid test image name: {image_name!r}")
    image_path = os.path.join(TEST_IMAGE_DIR, image_name)
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return predict(image, model_name)
# Dropdown options: one entry per configured model, in definition order.
model_choices = list(model_paths)
# Snapshot of the preloaded test image names, taken once when the module runs.
test_images = get_test_images()
# Define Gradio app.
# Both model dropdowns default to the first configured model instead of a
# hard-coded "resnet". This way, editing model_paths can never leave the
# default pointing at a choice that no longer exists.
default_model = model_choices[0] if model_choices else None

# NOTE(review): the original indentation was lost; the layout assumes each
# Row holds only the input widgets, with the button and output below it.
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")

    # Tab 1: classify an image the user uploads.
    with gr.Tab("Upload an Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            model_input_upload = gr.Dropdown(
                choices=model_choices, label="Select Model", value=default_model
            )
        classify_button_upload = gr.Button("Classify")
        output_text_upload = gr.Textbox(label="Predicted Class")
        classify_button_upload.click(
            classify_uploaded_image,
            inputs=[image_input, model_input_upload],
            outputs=output_text_upload,
        )

    # Tab 2: classify one of the images listed from TEST_IMAGE_DIR.
    with gr.Tab("Select a Preloaded Image"):
        with gr.Row():
            image_dropdown = gr.Dropdown(choices=test_images, label="Select a Test Image")
            model_input_preloaded = gr.Dropdown(
                choices=model_choices, label="Select Model", value=default_model
            )
        classify_button_preloaded = gr.Button("Classify")
        output_text_preloaded = gr.Textbox(label="Predicted Class")
        classify_button_preloaded.click(
            classify_preloaded_image,
            inputs=[image_dropdown, model_input_preloaded],
            outputs=output_text_preloaded,
        )

# Starts the web server as soon as this module is executed (no __main__ guard).
demo.launch()