# Source: Hugging Face Space file view (Space status when captured: Sleeping).
# File size: 5,561 bytes.
# The commit-hash column and the 1-176 line-number gutter of the web file
# viewer were extraction residue, not program text, and are omitted here.
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
from pathlib import Path
import os
from huggingface_hub import hf_hub_download
import numpy as np
class ModelPredictor:
    """Wrap a checkpointed image classifier and expose top-5 prediction.

    On construction the checkpoint is fetched with ``hf_hub_download``,
    loaded into an ``ImageNetModule``, moved to the selected device and put
    into eval mode. The preprocessing pipeline and the class-label table are
    built once and reused by every :meth:`predict` call.
    """

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """Download the checkpoint and prepare the model for inference.

        Args:
            model_repo: Hugging Face Hub repository id holding the checkpoint.
            model_filename: Name of the checkpoint file inside that repository.
            device: Torch device string. When falsy, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Load the model
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Setup transforms: resize, centre-crop to 224, convert to a tensor and
        # normalise per channel (the commonly used ImageNet mean/std values).
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Load ImageNet class labels
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Load the trained model from a checkpoint file.

        Args:
            checkpoint_path: Local path of the checkpoint to load.

        Returns:
            An ``ImageNetModule`` with the checkpoint's ``"state_dict"`` loaded.
        """
        # Imported lazily so the project-local training module is only needed
        # when a model is actually loaded.
        from pl_train import ImageNetModule

        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from repositories you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # NOTE(review): these constructor values presumably only satisfy the
        # module's signature and are not used at inference time - confirm
        # against pl_train.ImageNetModule.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self):
        """Load the class-label table used to name predictions.

        Returns:
            Dict mapping the stringified zero-based class index ("0", "1", ...)
            to a label. Read from ``data/imagenet-simple-labels.json`` when that
            file exists; otherwise 1000 placeholder names ``class_<i>``.
        """
        # For HuggingFace Spaces, we look for the labels file relative to the
        # working directory.
        labels_path = Path("data/imagenet-simple-labels.json")
        if labels_path.exists():
            with open(labels_path, encoding="utf-8") as f:
                # Expected to be a JSON array of names in class-index order -
                # TODO confirm it matches the label order used in training.
                data = json.load(f)
            # Keys must be zero-based: predict() looks labels up with the
            # index returned by torch.topk, which starts at 0. (The previous
            # i + 1 keys made class 0 a KeyError and shifted every other
            # label by one.)
            return {str(i): name for i, name in enumerate(data)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    def predict(self, image):
        """Make a top-5 prediction for a single image.

        Args:
            image: A numpy array (converted via ``Image.fromarray``), a file
                path string (opened with ``Image.open``), or any other object
                that is then treated as a PIL image.

        Returns:
            Dict mapping class label to probability for the five most likely
            classes. If anything raises, the error is printed and
            ``{"error": 1.0}`` is returned instead.
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image.astype("uint8"))
            elif isinstance(image, str):
                image = Image.open(image)

            # Ensure image is in RGB mode
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Apply transforms and add the batch dimension the model expects.
            image_tensor = self.transform(image).unsqueeze(0)
            image_tensor = image_tensor.to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # Get top 5 predictions
            top_probs, top_indices = torch.topk(probabilities, 5)

            results = {}
            for prob, idx in zip(top_probs[0], top_indices[0]):
                key = str(idx.item())
                # Fall back to a placeholder name if the labels file is
                # shorter than the model's output, rather than failing the
                # whole prediction.
                class_name = self.class_labels.get(key, f"class_{key}")
                results[class_name] = float(prob)
            return results
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}
# Initialize the predictor once at import time so the Gradio callback can
# reuse it. The name is bound to None first so that it always exists: if
# construction fails, later code sees None instead of an unbound global.
predictor = None
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    # Best-effort start-up: report the failure and keep the module importable.
    print(f"Error initializing predictor: {str(e)}")
def predict_image(image):
    """
    Callback wired into the Gradio interface.

    Args:
        image: the value Gradio passes from its image input, or None when
            nothing was supplied
    Returns:
        Whatever ``predictor.predict`` returns for the image, or a
        single-entry dictionary describing the problem when no image was
        given or the prediction call raised
    """
    if image is not None:
        try:
            predictions = predictor.predict(image)
        except Exception as e:
            # Any failure is reported on stdout and shown as a label.
            print(f"Error in predict_image: {str(e)}")
            return {"Error: Failed to process image": 1.0}
        else:
            return predictions

    return {"Error: No image provided": 1.0}
# Build the Gradio UI: one image input (delivered to the callback as a PIL
# image) feeding a label widget that shows the five highest-scoring classes.
# No `examples=` argument is passed, so the interface has no example images.
iface = gr.Interface(
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    analytics_enabled=False,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|