# Indian bovine breed classifier (Gradio app). Original file size: 2,236 bytes.
import warnings

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import models, transforms
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Trained weights file (upload your .pth model here). This is a relative path,
# so it is resolved against the current working directory.
MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth"

# Classifier output index i is reported as CLASS_NAMES[i]; the order must match
# the label order used during training (confirm against the training script).
CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
# -----------------------------------------------------------------------------
# Model: EfficientNet-B3 with a classification head sized to CLASS_NAMES
# -----------------------------------------------------------------------------
# weights=None builds the architecture only; no pretrained weights are downloaded.
model = models.efficientnet_b3(weights=None)
# classifier[1] is the final Linear layer; replace it with one that emits one
# logit per breed while keeping the backbone's feature width as its input size.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(CLASS_NAMES),
)
# =======================
# LOAD CHECKPOINT
# =======================
def _load_checkpoint(net, path, map_location):
    """Load the state_dict saved at *path* into *net* in place and return it.

    Every tensor whose name and shape match *net* is loaded, including the
    trained classifier head. Classifier tensors whose shape differs from *net*
    (e.g. a checkpoint trained with a different number of classes) are dropped
    instead of raising. Any other shape mismatch still raises ``RuntimeError``
    from ``load_state_dict``, as before.

    Args:
        net: the ``torch.nn.Module`` to populate.
        path: file passed to ``torch.load``; assumed to hold a plain state_dict
            (mapping of parameter name -> tensor) — confirm for new checkpoints.
        map_location: device the tensors are mapped onto while loading.

    Returns:
        The (filtered) state_dict that was passed to ``load_state_dict``.

    Warns:
        UserWarning if any head tensor was skipped, or if the checkpoint left
        parameters missing or contained unexpected keys. Missing parameters keep
        their random initialisation, so the model's predictions are unreliable.
    """
    state = torch.load(path, map_location=map_location)
    expected = net.state_dict()
    # Previously ALL "classifier" tensors were discarded, so the head stayed
    # random even when the checkpoint's head fit. Now only shape-mismatched
    # head tensors are dropped.
    mismatched = [
        key
        for key, tensor in state.items()
        if "classifier" in key
        and key in expected
        and tuple(tensor.shape) != tuple(expected[key].shape)
    ]
    for key in mismatched:
        del state[key]
    # strict=False tolerates missing/unexpected keys, so surface them explicitly
    # rather than silently serving a partly (or wholly) untrained model.
    result = net.load_state_dict(state, strict=False)
    if mismatched or result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {path!r} did not fully match the model: "
            f"skipped head tensors {mismatched}, "
            f"missing {result.missing_keys}, "
            f"unexpected {result.unexpected_keys}. "
            "Missing parameters keep their random initialisation.",
            stacklevel=2,
        )
    return state


checkpoint = _load_checkpoint(model, MODEL_PATH, device)
model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers
# -----------------------------------------------------------------------------
# Image preprocessing
# -----------------------------------------------------------------------------
# Resize to 300x300, convert the PIL image to a float CHW tensor in [0, 1], then
# normalise each channel with the standard ImageNet mean/std. These steps
# presumably mirror the training-time preprocessing — confirm against training.
transform = transforms.Compose(
    [
        transforms.Resize(size=(300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =======================
# PREDICTION FUNCTION
# =======================
def predict(image):
    """Classify the breed in *image* and return an annotated copy plus a label.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``(annotated_image, text)`` where ``text`` is
        ``"<breed> (<confidence>%)"`` with two decimals, and ``annotated_image``
        is an RGB copy of the input with that text drawn in red at (10, 10).
        When *image* is ``None``: ``(None, "Please upload an image.")``.
    """
    if image is None:
        # Gradio hands the function None for an empty image input; without this
        # guard the .convert() call below raises AttributeError.
        return None, "Please upload an image."

    # convert() returns a new image, so the caller's object is never drawn on.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)  # add a batch axis
    with torch.no_grad():
        output = model(img_tensor)
    probs = torch.nn.functional.softmax(output, dim=1)
    conf, pred_idx = torch.max(probs, dim=1)
    pred_label = CLASS_NAMES[pred_idx.item()]
    confidence = conf.item() * 100

    # Draw the label on the image.
    text = f"{pred_label} ({confidence:.2f}%)"
    draw = ImageDraw.Draw(image)
    draw.text((10, 10), text, fill="red", font=ImageFont.load_default())
    return image, text
# =======================
# GRADIO INTERFACE
# =======================
# One PIL image in; predict() returns (annotated image, label text), which maps
# onto the two outputs below in order.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), "text"],
    title="Indian Bovine Breed Classifier",
    description="Upload an image of a cow and get the breed prediction with confidence.",
)

if __name__ == "__main__":
    iface.launch()