Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import matplotlib.pyplot as plt | |
| import timm | |
class BaseModel(nn.Module):
    """Shared base class for the classifiers defined in this file.

    ``predict`` calls the module itself and applies a softmax, so subclasses
    are expected to implement ``forward`` and to override
    ``get_num_classes``.
    """

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax probabilities (over dim 1) for ``x``.

        The forward pass runs inside ``torch.no_grad()``, so the result does
        not carry gradient information.
        """
        with torch.no_grad():
            return self(x).softmax(dim=1)

    def get_num_classes(self) -> int:
        """Return the number of output classes.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError
class CNNModel(BaseModel):
    """Small convolutional classifier.

    Three Conv -> BatchNorm -> ReLU -> MaxPool stages (32, 64 and 128 output
    channels) are followed by global average pooling and a two-layer fully
    connected head with dropout.
    """

    def __init__(self, num_classes: int, input_size: int = 224):
        """Build the feature extractor and the classification head.

        Args:
            num_classes: Output width of the final linear layer.
            input_size: Accepted for interface compatibility; it is not used
                when building the layers.
        """
        super().__init__()

        # Consecutive pairs give (in_channels, out_channels) for each stage.
        # Layers are appended in the same order as a hand-written Sequential,
        # so state-dict keys ("conv_layers.0", "conv_layers.1", ...) match.
        channel_plan = (3, 32, 64, 128)
        feature_layers = []
        for in_channels, out_channels in zip(channel_plan, channel_plan[1:]):
            feature_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        # Global average pooling collapses each feature map to 1x1.
        feature_layers.append(nn.AdaptiveAvgPool2d(1))
        self.conv_layers = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the convolutional stages and the head."""
        return self.classifier(self.conv_layers(x))

    def get_num_classes(self) -> int:
        """Return the output width of the last linear layer."""
        final_layer = self.classifier[-1]
        return final_layer.out_features
class EfficientNetModel(BaseModel):
    """timm backbone followed by a dropout + linear classification head.

    The backbone is created with ``num_classes=0`` and its output is used as
    the feature vector fed to ``self.classifier``.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True
    ):
        """Create the backbone and size the head to match its output.

        Args:
            num_classes: Output width of the final linear layer.
            model_name: Name passed to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()
        self.base_model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )
        feature_dim = self._probe_feature_dim()
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes)
        )

    def _probe_feature_dim(self) -> int:
        """Measure the backbone's output width with one dummy forward pass.

        The probe runs in eval mode: in training mode, layers that track
        running statistics would update them from the random dummy input even
        inside ``torch.no_grad()``, contaminating (pretrained) weights. The
        previous train/eval mode is restored afterwards.
        """
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 224, 224)
                features = self.base_model(dummy_input)
        finally:
            self.base_model.train(was_training)
        return features.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract backbone features from ``x`` and classify them."""
        features = self.base_model(x)
        return self.classifier(features)

    def get_num_classes(self) -> int:
        """Return the output width of the last linear layer."""
        return self.classifier[-1].out_features
class AnimalClassifierApp:
    """Gradio application comparing the EfficientNet and CNN classifiers.

    On construction it selects a device, builds the image preprocessing
    pipeline and loads whichever model checkpoints exist in the working
    directory. ``predict`` is the callback wired into the Gradio interface.
    """

    # Bar colour for each model's chart; unknown model names fall back to gray.
    _BAR_COLORS = {"EfficientNet": "skyblue", "CNN": "lightcoral"}

    def __init__(self):
        """Set up device, labels, preprocessing and load available models."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.labels = ["bird", "cat", "dog", "horse"]
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        self.models = self.load_models()
        if not self.models:
            print("Warning: No models found in checkpoints directory!")

    def _load_checkpoint(self, display_name, build_model, checkpoint_path):
        """Build one model and load its checkpoint, or return ``None``.

        Args:
            display_name: Name used in log messages.
            build_model: Zero-argument callable returning a fresh model.
            checkpoint_path: Path of the ``.pth`` file to load.

        Returns:
            The model in eval mode on ``self.device``, or ``None`` when the
            checkpoint file is missing or anything fails while loading.
        """
        # Check first so a missing checkpoint does not trigger model
        # construction (which may download pretrained weights).
        if not os.path.exists(checkpoint_path):
            return None
        try:
            model = build_model()
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
            # Accept both {"model_state_dict": ...} wrappers and bare state dicts.
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            load_result = model.load_state_dict(state_dict, strict=False)
            # strict=False tolerates mismatches; report them so a checkpoint
            # that matched nothing is not mistaken for a trained model.
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"Warning: {display_name} checkpoint mismatch - "
                    f"missing keys: {list(load_result.missing_keys)}, "
                    f"unexpected keys: {list(load_result.unexpected_keys)}"
                )
            # Inputs are moved to self.device in predict(); the model must
            # live on the same device.
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Best effort: one broken checkpoint must not prevent the other
            # model (or the app) from starting.
            print(f"Error loading {display_name} model: {str(e)}")
            return None
        print(f"Successfully loaded {display_name} model")
        return model

    def load_models(self):
        """Load every available model.

        Returns:
            Dict mapping model name ("EfficientNet", "CNN") to a ready-to-use
            model. Models whose checkpoint is missing or fails to load are
            omitted.
        """
        num_classes = len(self.labels)
        candidates = (
            ("EfficientNet", lambda: EfficientNetModel(num_classes=num_classes),
             "efficientnet_best_model.pth"),
            ("CNN", lambda: CNNModel(num_classes=num_classes),
             "cnn_best_model.pth"),
        )
        models = {}
        for display_name, build_model, checkpoint_path in candidates:
            model = self._load_checkpoint(display_name, build_model, checkpoint_path)
            if model is not None:
                models[display_name] = model
        return models

    def _plot_probabilities(self, probabilities):
        """Draw one bar chart per model and return the matplotlib figure."""
        fig, axes = plt.subplots(1, len(probabilities), figsize=(12, 5), squeeze=False)
        for ax, (model_name, probs) in zip(axes[0], probabilities.items()):
            ax.bar(self.labels, probs, color=self._BAR_COLORS.get(model_name, "gray"))
            ax.set_title(f"{model_name} Predictions")
            ax.set_ylim(0, 1)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_ylabel("Probability")
        fig.tight_layout()
        # Drop pyplot's global reference so figures do not accumulate across
        # requests; the returned Figure object itself remains usable.
        plt.close(fig)
        return fig

    def _format_results(self, probabilities):
        """Render the per-model probabilities as a plain-text report."""
        lines = ["Model Predictions:", ""]
        for model_name, probs in probabilities.items():
            top_idx = int(np.argmax(probs))
            lines.append(f"{model_name}:")
            lines.append(f"Top prediction: {self.labels[top_idx]} ({probs[top_idx]:.2%})")
            lines.append("All probabilities:")
            lines.extend(
                f" {label}: {prob:.2%}" for label, prob in zip(self.labels, probs)
            )
            lines.append("")
        return "\n".join(lines) + "\n"

    def predict(self, image: "Image.Image"):
        """Classify ``image`` with every loaded model.

        Args:
            image: PIL image to classify; ``None`` (nothing uploaded) is
                tolerated.

        Returns:
            A two-element list ``[figure, text]`` matching the interface's
            two outputs. ``figure`` is ``None`` when no prediction could be
            made, in which case ``text`` explains why.
        """
        if not self.models:
            return [None, "No trained models found. Please train the models first."]
        if image is None:
            return [None, "Please upload an image first."]

        # Normalize expects exactly three channels, so force RGB.
        img_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)

        probabilities = {}
        with torch.no_grad():
            for model_name, model in self.models.items():
                output = model(img_tensor)
                probabilities[model_name] = F.softmax(output, dim=1).squeeze(0).cpu().numpy()

        return [
            self._plot_probabilities(probabilities),
            self._format_results(probabilities),
        ]

    def create_interface(self):
        """Return the Gradio interface wired to ``predict``."""
        return gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type="pil"),
            outputs=[
                gr.Plot(label="Prediction Probabilities"),
                gr.Textbox(label="Detailed Results", lines=10)
            ],
            title="Animal Classifier - Model Comparison",
            description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
        )
def main():
    """Create the classifier app and launch its Gradio interface."""
    AnimalClassifierApp().create_interface().launch()


# Only start the app when executed as a script, not when imported.
if __name__ == "__main__":
    main()