"""Gradio app that serves top-5 ImageNet-1K predictions from a checkpoint.

The checkpoint is downloaded from the Hugging Face Hub when the module is
imported, wrapped in a ``ModelPredictor`` and exposed through a Gradio
``Interface``.
"""

import json
import os
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# Location of the human-readable class names, relative to the working dir.
LABELS_PATH = Path("data/imagenet-simple-labels.json")
# Number of placeholder labels generated when the labels file is missing.
NUM_CLASSES = 1000
# Number of predictions returned per image.
TOP_K = 5


class ModelPredictor:
    """Load a checkpoint from the Hugging Face Hub and classify images."""

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: Optional[str] = None,
    ):
        """Download the checkpoint, build the model and the preprocessing.

        Args:
            model_repo: Hugging Face Hub repository id holding the checkpoint.
            model_filename: Name of the checkpoint file inside the repository.
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model
        checkpoint_path = hf_hub_download(
            repo_id=model_repo, filename=model_filename
        )
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Setup transforms
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Load ImageNet class labels
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Load the trained model from a checkpoint file.

        Args:
            checkpoint_path: Local path of the checkpoint to load.

        Returns:
            An ``ImageNetModule`` with the checkpoint's ``state_dict`` loaded.
        """
        # Imported lazily so the project-local training module is only
        # required when a model is actually being built.
        from pl_train import ImageNetModule

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints from a trusted repo.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self):
        """Load the mapping from class index to class name.

        Returns:
            Dict keyed by the 0-based class index as a string. Names come
            from ``LABELS_PATH`` when that file exists; otherwise
            placeholder names ``class_<i>`` are generated.
        """
        # For HuggingFace Spaces, the labels file is looked up relative to
        # the working directory.
        if LABELS_PATH.exists():
            with open(LABELS_PATH, encoding="utf-8") as f:
                names = json.load(f)
            # Keys must be 0-based: predict() looks labels up with the index
            # returned by torch.topk, which starts at 0. A 1-based mapping
            # raises KeyError for class 0 and shifts every other label.
            return {str(i): name for i, name in enumerate(names)}

        return {str(i): f"class_{i}" for i in range(NUM_CLASSES)}  # Fallback

    def predict(self, image):
        """Return the top predictions for a single image.

        Args:
            image: A PIL image, a numpy array, or a file path string.

        Returns:
            Dict mapping class name to probability for the top ``TOP_K``
            classes, or ``{"error": 1.0}`` if anything fails (the error is
            printed so it shows up in the application log).
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image.astype("uint8"))
            elif isinstance(image, str):
                image = Image.open(image)

            # The normalisation constants are defined for 3-channel input.
            if image.mode != "RGB":
                image = image.convert("RGB")

            batch = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # Never ask topk for more entries than there are classes.
            k = min(TOP_K, probabilities.shape[1])
            top_probs, top_indices = torch.topk(probabilities, k)

            results = {}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                # Fall back to a placeholder name rather than failing the
                # whole prediction when the labels file lacks an index.
                class_name = self.class_labels.get(str(idx), f"class_{idx}")
                results[class_name] = prob
            return results

        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}


# Initialize the predictor
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    print(f"Error initializing predictor: {str(e)}")
    # Keep the name bound so request handlers can report the failure
    # instead of raising NameError on every call.
    predictor = None


def predict_image(image):
    """Gradio interface function.

    Args:
        image: The uploaded image as delivered by the ``gr.Image`` input
            (a PIL image with ``type="pil"``), or ``None`` if nothing was
            uploaded.

    Returns:
        Dict of label to score, suitable for ``gr.Label``. Failures are
        reported as a single pseudo-label with score 1.0.
    """
    if image is None:
        return {"Error: No image provided": 1.0}

    if predictor is None:
        return {"Error: Model is not available": 1.0}

    try:
        return predictor.predict(image)
    except Exception as e:
        print(f"Error in predict_image: {str(e)}")
        return {"Error: Failed to process image": 1.0}


# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    # examples=(
    #     [
    #         ["ResNetVision-1K/data/ILSVRC2012_val_00000048.JPEG"],
    #         ["ResNetVision-1K/data/ILSVRC2012_val_00000090.JPEG"],
    #         ["ResNetVision-1K/data/ILSVRC2012_val_00000.JPEG"],
    #     ]
    #     if all(
    #         Path(f).exists()
    #         for f in [
    #             ["ResNetVision-1K/data/ILSVRC2012_val_00000048.JPEG"],
    #             ["ResNetVision-1K/data/ILSVRC2012_val_00000090.JPEG"],
    #             ["ResNetVision-1K/data/ILSVRC2012_val_00000.JPEG"],
    #         ]
    #     )
    #     else None
    # ),
    analytics_enabled=False,
)

# Launch the app
if __name__ == "__main__":
    iface.launch()