# resnet-train / app.py
# Sreekanth Tangirala
# adding ui
# f173f8e
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
def load_model_from_hub(repo_id: str):
    """Load an image-classification model and its preprocessor from the Hub.

    Args:
        repo_id: The repository ID (e.g. ``"username/model-name"``).

    Returns:
        A ``(model, processor)`` tuple: the classification model and the image
        processor stored in the same repository. ``processor`` is called as
        ``processor(images=..., return_tensors="pt")``.
    """
    model = AutoModelForImageClassification.from_pretrained(repo_id)
    # AutoImageProcessor is the supported loader for vision preprocessors;
    # AutoFeatureExtractor maps vision checkpoints onto the deprecated
    # *FeatureExtractor classes.
    processor = AutoImageProcessor.from_pretrained(repo_id)
    return model, processor
def predict(image_path: str, model, processor):
    """Classify an image stored on disk and return class probabilities.

    Args:
        image_path: Path (or file object) of the image to classify.
        model: Image-classification model, e.g. from ``load_model_from_hub``.
        processor: Preprocessor paired with ``model``; called as
            ``processor(images=..., return_tensors="pt")``.

    Returns:
        torch.Tensor: softmax over the last dimension of the model's logits.
        Presumably shaped ``(1, num_labels)`` for the single input image.
    """
    # Use a context manager so the underlying file handle is closed promptly,
    # and normalise the pixel mode: grayscale, palette or RGBA files would
    # otherwise reach the processor/model with an unexpected channel count.
    # convert() returns a fully loaded copy, so it stays valid after close.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.softmax(-1)
def predict_image(image):
    """Gradio callback: classify an uploaded image and describe the top class.

    Uses the module-level ``model`` and ``processor`` loaded at startup.

    Args:
        image: Value supplied by the ``gr.Image`` input — a numpy array (the
            default input type), a PIL image, or ``None`` when the user
            submits without uploading anything.

    Returns:
        str: ``"<label> (Confidence: NN.NN%)"`` for the most probable class,
        or a short prompt asking for an image when none was provided.
    """
    # Gradio passes None when the form is submitted with an empty image field;
    # Image.fromarray(None) would raise.
    if image is None:
        return "Please upload an image."
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The preprocessing/model pipeline expects 3-channel input.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        probs = model(**inputs).logits.softmax(-1)[0]

    # argmax returns the first maximal index, same tie-break as max() over range.
    top_idx = int(probs.argmax())
    confidence = float(probs[top_idx])

    # id2label may be absent or may not cover this index; fall back to a
    # generic name instead of raising KeyError.
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(top_idx, f"Class {top_idx}")
    return f"{label} (Confidence: {confidence:.2%})"
# Fetch the checkpoint once when the module is imported so every request
# served by the interface reuses the same model and processor.
MODEL_REPO_ID = "srtangirala/resnet50-exp"
model, processor = load_model_from_hub(MODEL_REPO_ID)

# UI wiring: a single image goes in, a single text line comes out.
image_input = gr.Image()
label_output = gr.Text()

iface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification",
    description="Upload an image to classify it!",
    # Each example row is a list of input values, e.g. ["path/to/example.jpg"];
    # none are bundled yet.
    examples=[],
)

if __name__ == "__main__":
    iface.launch()