# Author: TroglodyteDerivations
# Commit d183073 (verified): updated the notes on the `labels` variable declaration.
# The labels were set to only three Pokémon, although the model can recognise the
# remaining 715 as well. The code runs, but the hardcoded list of three Pokémon is not
# correct. The next revision of the Pokémon classifier should instead read the
# class-index-to-label mapping from the model's configuration and use it to report the
# model's predictions.
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
import torch
# Load the model and feature extractor
model_name = "imjeffhi/pokemon_classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Define the Pokémon labels.
# A hardcoded list (previously ['Jolteon', 'Kakuna', 'Mr. Mime']) does not match the
# model's output classes, so predicted indices were mapped to the wrong names or to
# "Unknown". Build the list from the model's own configuration instead, so that
# labels[i] is the name the checkpoint assigns to output class i.
# The keys are sorted numerically so list position equals class index, even if the
# config stores the keys as strings.
# NOTE(review): assumes the checkpoint's config ships real Pokémon names in id2label
# (not generic "LABEL_i" placeholders) — confirm against the model card.
labels = [
    name
    for _, name in sorted(model.config.id2label.items(), key=lambda item: int(item[0]))
]
def preprocess_image(img_pil):
    """Turn an input image into the tensor inputs the classifier consumes.

    Args:
        img_pil: The image to classify.

    Returns:
        The feature extractor's output, built with ``return_tensors="pt"`` so its
        values are PyTorch tensors that can be unpacked into ``model(**...)``.
    """
    return feature_extractor(images=img_pil, return_tensors="pt")
def predict_classification(img_pil):
    """Classify an image and report the winning label together with its probability.

    Args:
        img_pil: The image to classify; handed unchanged to ``preprocess_image``.

    Returns:
        A ``(predicted_class, confidence)`` pair. ``predicted_class`` is the entry of
        ``labels`` at the highest-scoring index, or ``"Unknown"`` when that index lies
        outside ``labels``. ``confidence`` is the softmax probability of that index
        for the first item in the batch.
    """
    model_inputs = preprocess_image(img_pil)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        scores = model(**model_inputs).logits
    best_idx = scores.argmax(-1).item()
    # Fall back to "Unknown" when the label list does not cover the predicted index.
    predicted_class = labels[best_idx] if best_idx < len(labels) else "Unknown"
    probabilities = torch.nn.functional.softmax(scores, dim=1).numpy()
    confidence = probabilities[0][best_idx]
    return predicted_class, confidence
# Function to handle the prediction in the Gradio interface
def gradio_predict(img_pil):
    """Format the classifier's prediction as text for the Gradio output box.

    Args:
        img_pil: The image delivered by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A string with the predicted class and its confidence (4 decimal places),
        or a prompt asking for an image when none was provided.
    """
    # Gradio hands the callback None when the Image component is empty; passing
    # that on would fail inside the feature extractor instead of informing the user.
    if img_pil is None:
        return "Please upload an image of a Pokemon."
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"
# Create Gradio interface
# type="pil" makes Gradio pass the callback a PIL image. That matches the prediction
# functions above, whose parameter is named img_pil. Left unset, the Image component
# would deliver a NumPy array instead.
input_image = gr.Image(type="pil", label="Upload an image of a Pokemon")
output_text = gr.Textbox(label="Predicted Class and Confidence")
iface = gr.Interface(
    fn=gradio_predict,
    inputs=input_image,
    outputs=output_text,
    title="Pokemon Classifier",
    description="Upload an image of a Pokemon and the classifier will tell you which one it is and the confidence level of the prediction.",
    allow_flagging="never",
)
# Start the web server (runs whenever this module is executed or imported).
iface.launch()