File size: 5,561 Bytes
905e42f
 
 
 
 
 
 
 
6fae1ea
905e42f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea55197
905e42f
 
 
ea55197
 
 
 
 
 
 
 
905e42f
ea55197
 
 
905e42f
ea55197
 
 
905e42f
ea55197
 
 
905e42f
ea55197
 
905e42f
ea55197
 
 
 
 
 
 
 
 
 
 
905e42f
 
 
ea55197
 
 
 
 
 
 
905e42f
 
 
 
 
ea55197
 
 
 
905e42f
ea55197
 
 
 
 
 
b555f64
ea55197
 
 
905e42f
 
 
 
 
 
 
 
 
e927cf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
905e42f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
from pathlib import Path
import os
from huggingface_hub import hf_hub_download
import numpy as np


class ModelPredictor:
    """Single-image classifier backed by a checkpoint from the Hugging Face Hub.

    Construction downloads the checkpoint, loads it into an ``ImageNetModule``,
    moves the model to the selected device in eval mode, builds the
    preprocessing pipeline and loads the index -> label-name mapping.
    """

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """Download the checkpoint and prepare everything needed for inference.

        Args:
            model_repo: Hub repository id passed to ``hf_hub_download``.
            model_filename: Checkpoint file name inside that repository.
            device: Torch device string. When falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Load the model
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing: resize, center-crop to 224, tensor conversion, then
        # per-channel normalization with the mean/std below.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Load ImageNet class labels
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Build an ``ImageNetModule`` and load weights from ``checkpoint_path``.

        Args:
            checkpoint_path: Local path of the checkpoint file.

        Returns:
            The module with ``checkpoint["state_dict"]`` loaded.
        """
        # Imported lazily so the project-local training module is only needed
        # when a model is actually loaded.
        from pl_train import ImageNetModule

        # NOTE(review): torch.load unpickles the file, so only checkpoints from
        # a trusted repository should be loaded here.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # The constructor arguments are only there to instantiate the module;
        # presumably none of them affect inference - confirm against pl_train.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self):
        """Return a mapping from class index (as a string) to label name.

        Labels are read from ``data/imagenet-simple-labels.json`` (resolved
        against the current working directory), which is expected to hold a
        JSON list of names. Keys are 0-based so they line up with the indices
        that ``predict`` gets from ``torch.topk``. When the file is missing,
        1000 placeholder names ``class_0`` .. ``class_999`` are returned.
        """
        labels_path = Path("data/imagenet-simple-labels.json")

        if labels_path.exists():
            with open(labels_path, encoding="utf-8") as f:
                names = json.load(f)
            # 0-based keys: position i in the list names model output index i
            # (assumes the list is in model-output order - confirm).
            return {str(i): name for i, name in enumerate(names)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    def predict(self, image):
        """Classify a single image and return its top predictions.

        Args:
            image: A numpy array, a file path string, or an object exposing
                ``mode``/``convert`` such as a PIL image.

        Returns:
            Dictionary mapping class label to probability for up to five
            classes, or ``{"error": 1.0}`` if anything fails.
        """
        try:
            # Normalize the accepted input kinds to a PIL image
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image.astype("uint8"))
            elif isinstance(image, str):
                # If image is a file path
                image = Image.open(image)

            # Ensure image is in RGB mode
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Apply transforms and add the batch dimension
            image_tensor = self.transform(image).unsqueeze(0)
            image_tensor = image_tensor.to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)

                # Top 5, clamped so a head with fewer classes does not raise
                top_k = min(5, probabilities.shape[1])
                top_probs, top_indices = torch.topk(probabilities, top_k)

                # Create results dictionary
                results = {}
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    key = str(idx.item())
                    # Fall back to a placeholder name instead of failing the
                    # whole prediction when a label is missing.
                    class_name = self.class_labels.get(key, f"class_{key}")
                    results[class_name] = float(prob)

            return results

        except Exception as e:
            # Best-effort: report the failure and return a sentinel result
            # rather than propagating the exception to the caller.
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}


# Initialize the predictor
# Bind the name up front so `predictor` always exists at module level, even
# when the download or checkpoint load fails; it then stays None.
predictor = None
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    # Best-effort startup: report the failure and keep the module importable.
    print(f"Error initializing predictor: {str(e)}")


def predict_image(image):
    """
    Callback wired into the Gradio interface.

    Args:
        image: the value Gradio passes for the image input, or None when
            nothing was provided
    Returns:
        The dictionary produced by the module-level predictor, or a
        single-entry error dictionary when there is no image or the call fails
    """
    if image is not None:
        try:
            outcome = predictor.predict(image)
        except Exception as e:
            # Any failure is reported and mapped to a fixed error entry
            print(f"Error in predict_image: {str(e)}")
            outcome = {"Error: Failed to process image": 1.0}
        return outcome

    return {"Error: No image provided": 1.0}


# Assemble the Gradio UI: one image input feeding predict_image, with the
# result rendered as a label widget limited to five classes.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

# NOTE: no `examples=` argument is passed. A commented-out draft pointed it at
# sample images under ResNetVision-1K/data/ (ILSVRC2012_val_00000048.JPEG,
# ILSVRC2012_val_00000090.JPEG, ILSVRC2012_val_00000.JPEG), guarded by a check
# that those files exist.
iface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_label_output,
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    analytics_enabled=False,
)

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    iface.launch()