# Source: Hugging Face Space (page status at capture time: Sleeping)
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import json | |
from pathlib import Path | |
import os | |
from huggingface_hub import hf_hub_download | |
import numpy as np | |
class ModelPredictor:
    """Run single-image inference with an ImageNet checkpoint from the Hugging Face Hub.

    The checkpoint is downloaded with ``hf_hub_download``, loaded into
    ``pl_train.ImageNetModule``, moved to ``device`` and switched to eval mode.
    ``predict`` turns one image into a ``{label: probability}`` dict of the
    highest-scoring classes.
    """

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """Download the checkpoint and build the model, transforms and label map.

        Args:
            model_repo: Hugging Face Hub repository id that holds the checkpoint.
            model_filename: Checkpoint file name inside ``model_repo``.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Load the model
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()
        # Eval preprocessing: resize the shorter side to 256, center-crop to
        # 224x224, convert to a tensor and normalize with ImageNet channel stats.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # Class labels keyed by the stringified 0-based class index.
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Load the trained model from a checkpoint file.

        Args:
            checkpoint_path: Local path to the checkpoint (``.ckpt``) file.

        Returns:
            An ``ImageNetModule`` with the checkpoint's ``"state_dict"`` loaded.
        """
        # Local import: `pl_train` is only needed once a model is actually loaded.
        from pl_train import ImageNetModule

        # NOTE(security): torch.load may unpickle arbitrary objects (depending on
        # the torch version's `weights_only` default). The file comes from the
        # Hub, so only point this at repositories you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # These constructor arguments look like training-time settings; the
        # weights come from the state_dict below. TODO confirm that none of
        # them affect the architecture.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(
        self, labels_path: str = "data/imagenet-simple-labels.json"
    ):
        """Return a mapping from stringified 0-based class index to label name.

        Args:
            labels_path: JSON file holding a list of label names, presumably
                ordered by class index. A relative path is resolved against the
                current working directory.

        Returns:
            ``{"0": name_0, "1": name_1, ...}``. If the file does not exist, the
            fallback ``{"0": "class_0", ..., "999": "class_999"}`` is returned.
        """
        path = Path(labels_path)
        if path.exists():
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
            # Keys must be 0-based to match the indices produced by torch.topk
            # in `predict` (and the fallback below). The previous `i + 1`
            # shifted every label and made class 0 raise KeyError.
            return {str(i): name for i, name in enumerate(data)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    @staticmethod
    def _to_rgb_image(image):
        """Convert a numpy array or file path to an image, then ensure RGB mode."""
        if isinstance(image, np.ndarray):
            # Arrays (e.g. from Gradio with type="numpy") become PIL images.
            image = Image.fromarray(image.astype("uint8"))
        elif isinstance(image, str):
            # File path. convert() forces a full load, so the file can be closed.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def predict(self, image, top_k: int = 5):
        """Make a prediction for a single image.

        Args:
            image: A PIL image, a numpy array (cast to uint8), or a file path.
            top_k: Maximum number of classes to return. It is capped at the
                number of model outputs.

        Returns:
            A dict mapping label names to softmax probabilities, in descending
            order of probability. On any failure, the error is printed and
            ``{"error": 1.0}`` is returned.
        """
        try:
            image = self._to_rgb_image(image)
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.shape[1])
            top_probs, top_indices = torch.topk(probabilities, k)

            results = {}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                class_name = self.class_labels.get(str(idx), f"class_{idx}")
                # Dict keys must be unique. If two classes share a display name,
                # keep both by tagging the later one with its index.
                if class_name in results:
                    class_name = f"{class_name} ({idx})"
                results[class_name] = float(prob)
            return results
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}
# Initialize the predictor once at import time. Bind the name to None first:
# if construction fails (download, checkpoint load, missing `pl_train`, ...),
# `predictor` is still defined instead of leaving an unbound module name.
predictor = None
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    print(f"Error initializing predictor: {str(e)}")
def predict_image(image):
    """Gradio callback that classifies one uploaded image.

    Args:
        image: The value from the image input component (the interface in this
            file uses ``gr.Image(type="pil")``, so normally a PIL image), or
            ``None`` when nothing was uploaded.

    Returns:
        A ``{label: confidence}`` dict suitable for ``gr.Label``. On missing
        input or any failure, a single-entry dict whose key is the error text.
    """
    # Guard clause: an empty input component yields None.
    if image is None:
        return {"Error: No image provided": 1.0}

    try:
        return predictor.predict(image)
    except Exception as err:
        # Any failure, including `predictor` never having been initialized,
        # is logged and shown in the UI rather than raised.
        print(f"Error in predict_image: {str(err)}")
        return {"Error: Failed to process image": 1.0}
def _build_interface():
    """Assemble the Gradio UI that routes uploaded images to `predict_image`."""
    # NOTE: no example images are configured. `examples=` expects a list of
    # input rows, e.g. [["path/to/image.JPEG"], ...].
    return gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil"),  # the callback receives a PIL image
        outputs=gr.Label(num_top_classes=5),  # display the five best labels
        title="ImageNet-1K Classification",
        description="Upload an image to classify it into one of 1000 ImageNet categories",
        analytics_enabled=False,
    )


# Create Gradio interface
iface = _build_interface()

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()