ResNetVision-1K / app.py
Adityak204's picture
Return raw predictions
b555f64
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
from pathlib import Path
import os
from huggingface_hub import hf_hub_download
import numpy as np
class ModelPredictor:
    """Classify images with a trained ImageNet checkpoint hosted on the Hugging Face Hub.

    The checkpoint is fetched with ``hf_hub_download``, loaded into
    ``pl_train.ImageNetModule`` and switched to eval mode on the chosen device.
    Inputs are preprocessed with resize(256) -> center-crop(224) -> tensor ->
    normalization with the ImageNet mean/std.
    """

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """Download the checkpoint, build the model and set up preprocessing.

        Args:
            model_repo: Hugging Face Hub repository id holding the checkpoint.
            model_filename: checkpoint file name inside ``model_repo``.
            device: torch device string; when falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Load the model
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        # Evaluation mode: layers such as dropout / batch norm switch to inference behavior.
        self.model.eval()

        # Setup transforms (ImageNet evaluation preprocessing)
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Mapping str(class index) -> human-readable label
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Load the trained model from a local checkpoint file.

        Args:
            checkpoint_path: path to a checkpoint dict that contains a
                ``"state_dict"`` entry with the module's weights.

        Returns:
            An ``ImageNetModule`` with the checkpoint weights loaded.
        """
        from pl_train import ImageNetModule

        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject checkpoints that store
        # non-tensor objects — confirm against the pinned torch version.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # These constructor arguments presumably mirror the training setup and
        # are only needed to instantiate the module; the weights come from the
        # checkpoint below. TODO confirm none of them change the architecture.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self):
        """Return a mapping from class index (as a string) to a readable label.

        Labels are read from ``data/imagenet-simple-labels.json``, resolved
        against the current working directory and assumed to be a JSON list
        of names ordered by class index. If the file is missing, placeholder
        names ``class_0`` ... ``class_999`` are returned.
        """
        labels_path = Path("data/imagenet-simple-labels.json")
        if labels_path.exists():
            with open(labels_path, encoding="utf-8") as f:
                data = json.load(f)
            # Keys are 0-based so they match the indices torch.topk yields in predict().
            return {str(i): name for i, name in enumerate(data)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    def predict(self, image, top_k: int = 5):
        """Classify a single image.

        Args:
            image: a PIL image, a numpy array (cast to uint8 and converted with
                ``Image.fromarray``), or a path (``str`` / ``os.PathLike``) to
                an image file.
            top_k: number of most probable classes to return; clamped to the
                number of model outputs.

        Returns:
            Dict mapping class label -> probability for the ``top_k`` most
            likely classes, ordered from most to least probable. On any
            failure the error is printed and ``{"error": 1.0}`` is returned.
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image.astype("uint8"))
            elif isinstance(image, (str, os.PathLike)):
                # Decode while the file is open so the handle is released promptly.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")

            # Ensure image is in RGB mode
            if image.mode != "RGB":
                image = image.convert("RGB")

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)

            k = min(top_k, probabilities.shape[1])
            top_probs, top_indices = torch.topk(probabilities, k)

            results = {}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                # Fall back to a placeholder if the labels file has fewer entries.
                class_name = self.class_labels.get(str(idx), f"class_{idx}")
                results[class_name] = float(prob)
            return results
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}
# Initialize the predictor once at import time so every request reuses the
# loaded weights. If loading fails, `predictor` is bound to None (instead of
# being left undefined) so request handlers can detect the missing model.
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    print(f"Error initializing predictor: {str(e)}")
    predictor = None
def predict_image(image):
    """Gradio callback: classify one uploaded image.

    Args:
        image: value delivered by the ``gr.Image`` input component (built with
            ``type="pil"`` in this app, so normally a PIL image), or None when
            nothing was uploaded.

    Returns:
        Dict of label -> probability for ``gr.Label``. On failure, a
        single-entry dict whose key describes the problem.
    """
    if image is None:
        return {"Error: No image provided": 1.0}

    # The module-level predictor may be missing or None if model loading
    # failed at startup; report that explicitly instead of as a generic error.
    try:
        active_predictor = predictor
    except NameError:
        active_predictor = None
    if active_predictor is None:
        return {"Error: Model not loaded": 1.0}

    try:
        return active_predictor.predict(image)
    except Exception as e:
        print(f"Error in predict_image: {str(e)}")
        return {"Error: Failed to process image": 1.0}
# Gradio UI: a single image upload (handed to the callback as a PIL image)
# and a label widget listing the five most probable classes.
# Example images (e.g. ILSVRC2012 validation JPEGs under data/) could be
# supplied through `examples=`; they are currently not enabled.
iface = gr.Interface(
    fn=predict_image,
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    analytics_enabled=False,
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()