ResNetVision-1K / app.py
Adityak204's picture
Return raw predictions
b555f64
raw
history blame
5.56 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
from pathlib import Path
import os
from huggingface_hub import hf_hub_download
import numpy as np
class ModelPredictor:
    """Download a trained ImageNet-1K checkpoint from the Hugging Face Hub and classify images.

    Inputs go through the transform pipeline built in ``__init__``: resize to
    256, center-crop to 224, convert to a tensor, and normalize with the
    ImageNet channel mean/std. The model's softmax output is reduced to the
    top-k human-readable labels.
    """

    # Label file location. It is resolved relative to the current working
    # directory, not this file's directory; on Hugging Face Spaces that is
    # presumably the repo root — confirm.
    DEFAULT_LABELS_PATH = "data/imagenet-simple-labels.json"

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """Download the checkpoint, build the model, and prepare transforms and labels.

        Args:
            model_repo: Hugging Face Hub repo id holding the checkpoint.
            model_filename: Checkpoint filename inside ``model_repo``.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Fetch the checkpoint; the returned local path is handed to torch.load.
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()
        # Evaluation-time preprocessing; mean/std are the ImageNet channel statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # Load ImageNet class labels.
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Build an ``ImageNetModule`` and load weights from a checkpoint file.

        Args:
            checkpoint_path: Local path to a checkpoint dict containing a
                ``"state_dict"`` entry (presumably saved by PyTorch Lightning).

        Returns:
            The ``ImageNetModule`` instance with the checkpoint weights loaded.
        """
        # Local import: the project-local training module is only imported
        # when a model is actually loaded.
        from pl_train import ImageNetModule

        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if
        # this checkpoint stores non-tensor objects, loading may need
        # weights_only=False (only ever for trusted checkpoints) — confirm.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # The constructor requires training hyperparameters; presumably only
        # the loaded weights matter for inference — verify against pl_train.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self, labels_path=None):
        """Return a mapping from class index (as a string) to a readable label.

        Args:
            labels_path: JSON file containing a list of label strings in
                class-index order. Defaults to ``DEFAULT_LABELS_PATH``.

        Returns:
            A dict keyed ``"0"`` .. ``str(n - 1)``, matching the 0-based
            indices produced by ``torch.topk`` in :meth:`predict`. When the
            file does not exist, returns ``"class_<i>"`` placeholders for
            1000 classes.
        """
        path = Path(
            labels_path if labels_path is not None else self.DEFAULT_LABELS_PATH
        )
        if path.exists():
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
            # Keys must be 0-based so that model output index i maps to data[i].
            # Starting at "1" would shift every label and make index 0 unmappable.
            return {str(i): name for i, name in enumerate(data)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    def predict(self, image, top_k: int = 5):
        """Classify one image and return its most likely labels.

        Args:
            image: A numpy array (cast to uint8 and converted via
                ``Image.fromarray``), a file path, or a PIL image.
            top_k: Number of highest-probability classes to return. It is
                clamped to the number of model outputs.

        Returns:
            A ``{label: probability}`` dict ordered from most to least likely.
            If anything fails, the error is printed and ``{"error": 1.0}`` is
            returned instead.
        """
        try:
            if isinstance(image, np.ndarray):
                # Gradio may deliver raw pixel arrays.
                image = Image.fromarray(image.astype("uint8"))
            elif isinstance(image, str):
                # convert() loads the pixel data, so the file handle can be
                # closed right away instead of leaking.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")
            # Ensure the image is in RGB mode (e.g. grayscale or RGBA inputs).
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Add a batch dimension and move to the model's device.
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.size(1))
            top_probs, top_indices = torch.topk(probabilities, k)

            results = {}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                # Fall back to a placeholder if the label file is shorter
                # than the model's output layer.
                class_name = self.class_labels.get(str(idx), f"class_{idx}")
                results[class_name] = prob
            return results
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}
# Initialize the predictor once at import time so every request reuses the
# loaded model. If loading fails (e.g. the checkpoint download or the
# pl_train import raises), the app still starts. `predictor` is explicitly
# left as None rather than unbound, so later requests fail on a clear
# attribute error instead of an undefined name.
predictor = None
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    print(f"Error initializing predictor: {str(e)}")
def predict_image(image):
    """Gradio callback: classify ``image`` using the module-level ``predictor``.

    Args:
        image: The value from the ``gr.Image`` input component. With
            ``type="pil"`` this is a PIL image; ``ModelPredictor.predict``
            also accepts numpy arrays and file paths.

    Returns:
        A ``{label: confidence}`` dict for ``gr.Label``. If no image was
        given or prediction raised, returns a single-entry dict whose key
        describes the error.
    """
    if image is not None:
        try:
            result = predictor.predict(image)
        except Exception as exc:
            # Log the details server-side; the UI only shows a generic message.
            print(f"Error in predict_image: {exc!s}")
            result = {"Error: Failed to process image": 1.0}
    else:
        result = {"Error: No image provided": 1.0}
    return result
# Build the Gradio UI: one image upload in, a top-5 label widget out.
# type="pil" asks Gradio to pass predict_image a PIL image.
# NOTE: to add example images, pass examples=[[path], ...], keeping only
# paths for which Path(path).exists() is true.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    analytics_enabled=False,
)
def _launch_app():
    """Start the Gradio server for the module-level ``iface``."""
    iface.launch()


# Only launch when executed as a script; importing this module just builds
# `iface` without starting a server.
if __name__ == "__main__":
    _launch_app()