"""Gradio web app that classifies an uploaded image with a pretrained ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Hugging Face Hub checkpoint used for both preprocessing and classification.
MODEL_NAME = "google/vit-base-patch16-224"

# Load the pretrained preprocessor and model once, at import time, so every
# request reuses them instead of reloading the weights per call.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)


def classify_image(image):
    """Return the predicted class label for an image.

    Args:
        image: The image as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``; ``None`` when the user submits
            without uploading anything.

    Returns:
        The label string from ``model.config.id2label`` for the class with
        the highest logit.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        # Gradio passes None for an empty image input; Image.fromarray(None)
        # would fail with an unhelpful error deep inside PIL.
        raise gr.Error("Please upload an image.")

    # Convert to RGB so grayscale or RGBA uploads become 3-channel input.
    pil_image = Image.fromarray(image).convert("RGB")
    inputs = feature_extractor(images=pil_image, return_tensors="pt")

    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Gradio interface: one image input, the predicted label as text output.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Image Classification App",
)

if __name__ == "__main__":
    # Launch only when run as a script, not when this module is imported.
    app.launch()