
Updated the comments on lines 12-16 above the `labels` variable declaration: # Define the Pokémon labels # Although `labels` is set with only 3 Pokémon, the remaining 715 Pokémon can also be leveraged. # The app functions, but the hard-coded list of three Pokémon is not correct. # The next variant of the Pokémon classifier should drop the hard-coded list and read the actual model's configuration, # which yields the mapping from class indices to labels, and then use that mapping to report the model's predictions on the Pokémon.
d183073
verified
import gradio as gr | |
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
from PIL import Image | |
import numpy as np | |
import torch | |
# Load the pretrained Pokémon classifier and its matching feature extractor.
model_name = "imjeffhi/pokemon_classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Define the Pokémon labels.
# A hard-coded list of three names cannot cover every class index the model
# can predict, so any other prediction would be reported as "Unknown".
# Instead, read the names from the model's own configuration, which maps
# class indices to labels. The list is built in class-index order so that
# labels[i] is the name for predicted class index i.
labels = [model.config.id2label[class_idx] for class_idx in range(model.config.num_labels)]
# Convert an input image into the tensors the model consumes
def preprocess_image(img_pil):
    """Run the module-level feature extractor on a single image.

    Args:
        img_pil: The image to classify, passed straight to the feature
            extractor as its ``images`` argument.

    Returns:
        The feature extractor's output, requested with ``return_tensors="pt"``.
    """
    return feature_extractor(images=img_pil, return_tensors="pt")
# Classify one image and report the winning label with its probability
def predict_classification(img_pil):
    """Predict the class of an image and the model's confidence in it.

    Args:
        img_pil: The image to classify; it is preprocessed with
            ``preprocess_image`` before being fed to the model.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        entry of ``labels`` at the top-scoring index, or ``"Unknown"`` when
        that index is not below ``len(labels)``. ``confidence`` is the
        softmax probability of the top-scoring index.
    """
    model_inputs = preprocess_image(img_pil)

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        logits = model(**model_inputs).logits

    # .item() requires a single-element result, i.e. one image per call.
    top_idx = logits.argmax(-1).item()

    # Softmax over dim 1, then read row 0 at the winning index.
    probabilities = torch.nn.functional.softmax(logits, dim=1).numpy()
    confidence = probabilities[0][top_idx]

    # Fall back to "Unknown" when the label list has no entry for the index.
    predicted_class = labels[top_idx] if top_idx < len(labels) else "Unknown"
    return predicted_class, confidence
# Function to handle the prediction in the Gradio interface
def gradio_predict(img_pil):
    """Format the classifier's prediction as text for the Gradio output box.

    Args:
        img_pil: The image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        A string with the predicted class and the confidence formatted to
        four decimal places, or a prompt asking for an image when none was
        provided.
    """
    # With no image there is nothing to classify; reply with a prompt instead
    # of letting the preprocessing/model pipeline raise on a None input.
    if img_pil is None:
        return "Please upload an image of a Pokemon."

    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"
# Build the Gradio UI: one image input wired to one text output
input_image = gr.Image(label="Upload an image of a Pokemon")
output_text = gr.Textbox(label="Predicted Class and Confidence")

# Collect the interface settings in one place before constructing it.
interface_options = dict(
    fn=gradio_predict,
    inputs=input_image,
    outputs=output_text,
    title="Pokemon Classifier",
    description=(
        "Upload an image of a Pokemon and the classifier will tell you "
        "which one it is and the confidence level of the prediction."
    ),
    allow_flagging="never",
)
iface = gr.Interface(**interface_options)

# Start serving the app
iface.launch()