Spaces:
Sleeping
Sleeping
File size: 2,642 Bytes
9c8dfb8 de2aabe f173f8e de2aabe 9c8dfb8 de2aabe 9c8dfb8 de2aabe 9c8dfb8 de2aabe 9c8dfb8 de2aabe 9c8dfb8 f173f8e de2aabe 9c8dfb8 f173f8e de2aabe f173f8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import gradio as gr
def load_model_from_hub(repo_id: str):
    """Fetch an image-classification model and its preprocessor.

    Both objects are created with ``from_pretrained`` from the same
    repository identifier.

    Args:
        repo_id: Hub repository identifier, e.g. ``'username/model-name'``.

    Returns:
        A ``(model, processor)`` tuple: the image-classification model and
        the feature extractor/processor used to prepare its inputs.
    """
    # The model is instantiated first, then the matching preprocessor.
    classifier = AutoModelForImageClassification.from_pretrained(repo_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    return classifier, feature_extractor
def predict(image_path: str, model, processor):
    """Classify the image stored at ``image_path``.

    Args:
        image_path: Path to the input image file.
        model: Loaded image-classification model. It is called with the
            processor's outputs as keyword arguments, and its result must
            expose a ``logits`` attribute.
        processor: Feature extractor/processor that converts a PIL image
            into model inputs (``return_tensors="pt"``).

    Returns:
        torch.Tensor: Softmax probabilities taken over the last dimension of
        ``outputs.logits`` (presumably shape ``(1, num_labels)`` for a single
        image — confirm against the model).
    """
    # The context manager releases the underlying file handle. convert("RGB")
    # loads the pixels into a new image that stays valid after the file is
    # closed, and it normalizes grayscale/RGBA/palette files to the three
    # channels the processor expects.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.softmax(-1)
def predict_image(image):
    """Classify an image submitted through the Gradio interface.

    Uses the module-level ``model`` and ``processor`` loaded at startup.

    Args:
        image: Value from the Gradio image input. This is either a PIL image
            or an array accepted by ``Image.fromarray`` (numpy in Gradio's
            default mode). It may be ``None`` when the user submits without
            uploading anything.

    Returns:
        str: ``"<label> (Confidence: NN.NN%)"`` for the top-scoring class, or
        a short prompt asking for an image when none was provided.
    """
    # Gradio passes None when the form is submitted without an image.
    # Image.fromarray(None) would raise, so ask for input instead.
    if image is None:
        return "Please upload an image."

    # Convert from numpy array to PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The processor normalizes three channels; grayscale or RGBA input
    # would otherwise break it.
    image = image.convert("RGB")

    # Process image and get prediction
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = outputs.logits.softmax(-1)[0]

    # Top-1 class. torch.max returns the first index on ties, matching the
    # previous max(range(...)) behavior.
    top_score, top_index = probabilities.max(-1)
    top_pred_idx = int(top_index)
    confidence = float(top_score)

    # Hugging Face configs define id2label by default, so a hasattr() check
    # does not guard against an index missing from the mapping. Look the
    # index up with a generic fallback instead.
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(top_pred_idx, f"Class {top_pred_idx}")
    return f"{label} (Confidence: {confidence:.2%})"
# Hugging Face Hub repository of the checkpoint this app serves.
MODEL_REPO_ID = "srtangirala/resnet50-exp"

# Load once at import time; predict_image() reads these module-level names.
model, processor = load_model_from_hub(MODEL_REPO_ID)

# Image in, text out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Text(),
    title="Image Classification",
    description="Upload an image to classify it!",
    # Optional sample inputs shown below the UI, one list per example,
    # e.g. ["path/to/example1.jpg"].
    examples=[],
)

if __name__ == "__main__":
    # Start the Gradio web server only when executed as a script.
    iface.launch()