# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
import torch | |
from PIL import Image | |
import gradio as gr | |
def load_model_from_hub(repo_id: str):
    """
    Load an image-classification model and its preprocessor from the Hugging Face Hub.

    Args:
        repo_id: The repository ID (e.g., 'username/model-name').

    Returns:
        model: The loaded image-classification model.
        processor: The image processor (or, on transformers releases that
            predate image processors, the feature extractor) used to turn
            images into inputs for ``model``.
    """
    # Vision "feature extractors" are deprecated in transformers in favour of
    # image processors. Prefer AutoImageProcessor, and fall back to the
    # module-level AutoFeatureExtractor only when the installed transformers
    # does not provide it.
    try:
        from transformers import AutoImageProcessor as processor_cls
    except ImportError:
        processor_cls = AutoFeatureExtractor
    model = AutoModelForImageClassification.from_pretrained(repo_id)
    processor = processor_cls.from_pretrained(repo_id)
    return model, processor
def predict(image_path: str, model, processor):
    """
    Make a prediction for a single image file using the loaded model.

    Args:
        image_path: Path to the input image file.
        model: Loaded image-classification model. It is called with the
            processor's output as keyword arguments, and its output must
            expose ``.logits``.
        processor: Feature extractor/processor that converts a PIL image
            into model inputs (called with ``return_tensors="pt"``).

    Returns:
        torch.Tensor: Softmax probabilities taken over the last dimension of
            the model's logits.
    """
    # Open the file in a context manager so its handle is always released,
    # including for multi-frame formats. convert("RGB") returns a fully loaded
    # 3-band copy, so grayscale, RGBA and palette files reach the processor in
    # the same channel layout as ordinary RGB images.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.softmax(-1)
def predict_image(image):
    """
    Gradio interface function for prediction.

    Uses the module-level ``model`` and ``processor`` loaded at startup.

    Args:
        image: Image uploaded through the Gradio interface. This is either a
            PIL image or an array accepted by ``PIL.Image.fromarray``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        str: The top predicted label with its confidence score, or a prompt
            to upload an image when ``image`` is ``None``.
    """
    # Gradio passes None for an empty input component. Answer with a prompt
    # instead of letting Image.fromarray(None) raise.
    if image is None:
        return "Please upload an image."
    # Convert from numpy array to PIL Image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalize to 3-band RGB so grayscale, RGBA and palette uploads are
    # preprocessed the same way as ordinary RGB images.
    image = image.convert("RGB")

    # Process the image and get class probabilities for the single image in
    # the batch.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = logits.softmax(-1)[0]

    # argmax returns the first maximal index on ties, matching the previous
    # max(range(...)) selection.
    top_pred_idx = int(probs.argmax())
    confidence = float(probs[top_pred_idx])

    # transformers configs always define id2label, so a hasattr() check never
    # selects a fallback. Use .get() so an index missing from the mapping
    # still gets a readable label instead of raising KeyError.
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(top_pred_idx, f"Class {top_pred_idx}")
    return f"{label} (Confidence: {confidence:.2%})"
# Hub repository holding the fine-tuned classifier and its preprocessor config.
_HUB_REPO_ID = "srtangirala/resnet50-exp"

# Load once when the module is imported. predict_image reads these globals.
model, processor = load_model_from_hub(_HUB_REPO_ID)
# Clickable example inputs shown below the interface. Each entry is a list
# with one value per input component, e.g. ["path/to/example1.jpg"].
_EXAMPLES = []

# Gradio UI: one image input wired to predict_image, rendered as text.
iface = gr.Interface(
    fn=predict_image,
    title="Image Classification",
    description="Upload an image to classify it!",
    inputs=gr.Image(),
    outputs=gr.Text(),
    examples=_EXAMPLES,
)
if __name__ == "__main__":
    # Start the app only when this file is executed directly, not on import.
    iface.launch()