"""Gradio demo that classifies Indian bovine breeds with a fine-tuned EfficientNet-B3."""

import warnings

import torch
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import gradio as gr

# =======================
# CONFIGURATION
# =======================
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth"  # Upload your .pth model here
# Entry i labels output logit i of the classifier head; presumably this matches
# the label order used during training — confirm.
CLASS_NAMES = ["Gir", "Deoni", "Murrah"]

# =======================
# MODEL: EfficientNetB3
# =======================
model = models.efficientnet_b3(weights=None)
# Replace the final Linear layer with one producing one logit per breed.
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))


# =======================
# LOAD CHECKPOINT
# =======================
def _extract_state_dict(checkpoint):
    """Return the flat parameter dict contained in a ``torch.load`` result.

    Accepts either a flat ``state_dict`` or a training checkpoint that nests it
    under ``"state_dict"`` / ``"model_state_dict"``, and strips the ``"module."``
    key prefix that ``nn.DataParallel`` / DDP wrappers add.
    """
    for wrapper_key in ("state_dict", "model_state_dict"):
        if isinstance(checkpoint.get(wrapper_key), dict):
            checkpoint = checkpoint[wrapper_key]
            break
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


def _load_checkpoint(net, path, map_location):
    """Load weights from *path* into *net*, keeping every tensor whose shape fits.

    The trained classifier head is loaded along with the backbone; discarding it
    would leave a randomly initialised head and meaningless predictions.
    Tensors whose shape disagrees with *net* (e.g. a head trained for a different
    number of classes) are skipped with a warning rather than crashing, and any
    parameters that end up not loaded are reported via ``warnings.warn``.

    Returns:
        *net*, with the compatible weights loaded in place.
    """
    # NOTE(security): torch.load unpickles the file — only load checkpoints you trust.
    raw = torch.load(path, map_location=map_location)
    state_dict = _extract_state_dict(raw)

    expected = net.state_dict()
    compatible, mismatched = {}, []
    for key, tensor in state_dict.items():
        if key in expected and getattr(tensor, "shape", None) != expected[key].shape:
            mismatched.append(key)
        else:
            compatible[key] = tensor
    if mismatched:
        warnings.warn(f"Skipped checkpoint tensors with mismatched shapes: {mismatched}")

    result = net.load_state_dict(compatible, strict=False)
    if result.missing_keys:
        # strict=False would otherwise hide these silently.
        warnings.warn(
            "Parameters not found in checkpoint (left randomly initialised): "
            f"{result.missing_keys}"
        )
    return net


_load_checkpoint(model, MODEL_PATH, device)
model.to(device)
model.eval()

# =======================
# IMAGE TRANSFORMS
# =======================
# 300x300 resize + ImageNet mean/std normalisation; presumably matches the
# preprocessing used when the checkpoint was trained — confirm.
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# =======================
# PREDICTION FUNCTION
# =======================
def predict(image):
    """Classify a cattle photo and return it annotated with the predicted breed.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        ``(annotated_image, text)`` where *text* is ``"<breed> (<conf>%)"``;
        ``(None, message)`` when no image was provided.
    """
    if image is None:
        return None, "Please upload an image."

    # convert() returns a converted copy, so drawing below leaves the input untouched.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(img_tensor), dim=1)
        conf, pred_idx = torch.max(probs, dim=1)

    pred_label = CLASS_NAMES[pred_idx.item()]
    confidence = conf.item() * 100
    text = f"{pred_label} ({confidence:.2f}%)"

    # Draw label on image
    draw = ImageDraw.Draw(image)
    draw.text((10, 10), text, fill="red", font=ImageFont.load_default())

    return image, text


# =======================
# GRADIO INTERFACE
# =======================
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), "text"],
    title="Indian Bovine Breed Classifier",
    description="Upload an image of a cow and get the breed prediction with confidence.",
)

if __name__ == "__main__":
    iface.launch()