"""Gradio front end for the plant disease classifiers.

Builds one ``PlantDiseaseClassifier`` per supported model at import time and
serves a small Blocks UI that classifies an uploaded image with the model
selected in a dropdown.
"""

import os

import gradio as gr
from PIL import Image
from torch.utils.data import DataLoader

from plant_disease_classifier import PlantDiseaseClassifier

# Define model paths and types.
# Checkpoint file for each supported model, keyed by the model type string
# that is passed to PlantDiseaseClassifier.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}

# Derived from model_paths so the two can never drift out of sync (the
# original kept a separate, manually ordered list and zipped it positionally).
model_types = list(model_paths)

# One classifier instance per model, created once at start-up.
classifiers = {
    model_type: PlantDiseaseClassifier(model_type, model_path)
    for model_type, model_path in model_paths.items()
}


def predict(image, model_name):
    """Classify ``image`` with the classifier registered under ``model_name``.

    Args:
        image: Image to classify; the UI supplies a PIL image
            (``gr.Image(type="pil")``).
        model_name: Key into ``classifiers`` (one of ``model_paths``' keys).

    Returns:
        Whatever ``PlantDiseaseClassifier.predict_image`` returns for the image.

    Raises:
        KeyError: If ``model_name`` is not a registered model.
    """
    classifier = classifiers[model_name]
    predicted_class = classifier.predict_image(image)
    return predicted_class


# Gradio Interface
def classify_image(image, model_name):
    """Button callback: validate the UI inputs, then delegate to ``predict``.

    Args:
        image: Value of the image component; ``None`` when nothing is uploaded.
        model_name: Value of the model dropdown.

    Returns:
        The prediction from ``predict``, shown in the output textbox.

    Raises:
        gr.Error: If no image was uploaded or the model name is unknown, so
            the user sees a readable message instead of a raw traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before classifying.")
    if model_name not in classifiers:
        raise gr.Error(f"Unknown model: {model_name!r}")
    return predict(image, model_name)


model_choices = list(model_paths.keys())

# Define Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        # The default must be one of the choices. The original default
        # "ResNet" matched neither a choice nor a key of `classifiers`, so
        # classifying with the untouched dropdown failed.
        model_input = gr.Dropdown(
            choices=model_choices,
            label="Select Model",
            value=model_choices[0],
        )
    classify_button = gr.Button("Classify")
    output_text = gr.Textbox(label="Predicted Class")
    classify_button.click(
        classify_image,
        inputs=[image_input, model_input],
        outputs=output_text,
    )

# Launched at module level, as in the original script.
demo.launch()