File size: 2,236 Bytes
6581479
 
126fb89
6581479
 
 
126fb89
6581479
 
d6ec985
6581479
 
 
126fb89
6581479
66d925a
6581479
126fb89
66d925a
 
 
126fb89
66d925a
 
126fb89
 
6581479
 
 
 
126fb89
6581479
 
 
 
 
 
 
 
 
126fb89
6581479
 
 
 
126fb89
6581479
 
126fb89
 
 
 
 
 
 
 
 
 
 
 
 
6581479
 
126fb89
6581479
 
 
 
126fb89
6581479
126fb89
6581479
 
126fb89
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import gradio as gr

# =======================
# CONFIGURATION
# =======================
# Run inference on the GPU when PyTorch can see one; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Trained weights file, resolved relative to the working directory
# (upload your .pth model here).
MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth"

# Labels indexed by the classifier's output position. Presumably this must
# match the class order used at training time — confirm.
CLASS_NAMES = ["Gir", "Deoni", "Murrah"]

# =======================
# MODEL: EfficientNetB3
# =======================
# Randomly initialised EfficientNet-B3 (weights=None: no pretrained weights).
model = models.efficientnet_b3(weights=None)
# Swap the Linear layer at classifier[1] for one that emits a logit per
# entry in CLASS_NAMES, keeping the original input width.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(CLASS_NAMES),
)

# =======================
# LOAD CHECKPOINT
# =======================
def _extract_state_dict(loaded):
    """Return the tensor-name -> tensor mapping held by a ``torch.load`` result.

    Accepts a plain state dict, a whole pickled ``torch.nn.Module``, or a
    training checkpoint dict that nests the weights under ``"state_dict"`` or
    ``"model_state_dict"``. A leading ``"module."`` (added when a model is
    wrapped in ``nn.DataParallel``) is stripped from every key so the names
    line up with ``model``'s own parameter names.
    """
    if isinstance(loaded, torch.nn.Module):
        loaded = loaded.state_dict()
    for nested_key in ("state_dict", "model_state_dict"):
        if isinstance(loaded.get(nested_key), dict):
            loaded = loaded[nested_key]
            break
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in loaded.items()
    }


def _load_compatible_weights(net, state_dict):
    """Copy into *net* every checkpoint tensor whose name and shape match.

    Tensors whose shape differs from the model's (e.g. a classifier head
    trained for a different number of classes) are skipped instead of making
    ``load_state_dict`` fail. The following are reported with
    ``warnings.warn`` so a bad checkpoint is visible:

    - skipped tensors;
    - model parameters the checkpoint did not provide, which keep their
      random initialisation;
    - checkpoint entries the model does not use.

    Returns:
        The result of ``net.load_state_dict`` (``missing_keys`` /
        ``unexpected_keys``).
    """
    import warnings

    expected = net.state_dict()
    compatible = {}
    mismatched = []
    for name, tensor in state_dict.items():
        if name in expected and expected[name].shape != tensor.shape:
            mismatched.append(name)
        else:
            compatible[name] = tensor

    result = net.load_state_dict(compatible, strict=False)
    if mismatched:
        warnings.warn(
            f"Skipped checkpoint tensors with mismatched shapes: {mismatched}"
        )
    if result.missing_keys:
        warnings.warn(
            "Model parameters missing from the checkpoint (left randomly "
            f"initialised): {result.missing_keys}"
        )
    if result.unexpected_keys:
        warnings.warn(
            "Ignored checkpoint entries unknown to the model: "
            f"{result.unexpected_keys}"
        )
    return result


# NOTE(review): torch.load unpickles MODEL_PATH, which can execute arbitrary
# code embedded in the file. Only load checkpoints from a trusted source; if
# the file is a plain state dict, torch.load(..., weights_only=True)
# (torch >= 1.13) is the safer call.
checkpoint = _extract_state_dict(torch.load(MODEL_PATH, map_location=device))
# Keep the trained classifier head whenever its shape matches the
# len(CLASS_NAMES)-way head built on the model. Discarding it would leave the
# output layer randomly initialised and make every prediction meaningless.
_load_compatible_weights(model, checkpoint)

model.to(device)
model.eval()

# =======================
# IMAGE TRANSFORMS
# =======================
transform = transforms.Compose(
    [
        # Force a fixed 300x300 input; the aspect ratio is not preserved.
        transforms.Resize((300, 300)),
        # PIL image -> float tensor (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std. Presumably the model was
        # trained with the same normalisation — confirm against training code.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# =======================
# PREDICTION FUNCTION
# =======================
def predict(image):
    """Classify the cattle breed in *image* and return an annotated copy.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        Tuple ``(annotated_image, text)``:

        - ``text`` is ``"<breed> (<confidence>%)"``, with the confidence given
          to two decimals.
        - ``annotated_image`` is an RGB copy of the input with that text drawn
          in red at the top-left corner.

        When no image was provided, returns ``(None, "Please upload an image.")``.
    """
    # Gradio's Image input yields None when nothing was uploaded; without this
    # guard the .convert() call below raises AttributeError.
    if image is None:
        return None, "Please upload an image."

    # Image.convert returns a converted copy, so drawing on it below does not
    # modify the caller's image.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        logits = model(img_tensor)
        probs = torch.nn.functional.softmax(logits, dim=1)
        conf, pred_idx = torch.max(probs, dim=1)

    pred_label = CLASS_NAMES[pred_idx.item()]
    confidence = conf.item() * 100
    text = f"{pred_label} ({confidence:.2f}%)"

    # Draw the label on the image.
    draw = ImageDraw.Draw(image)
    draw.text((10, 10), text, fill="red", font=ImageFont.load_default())

    return image, text

# =======================
# GRADIO INTERFACE
# =======================
# One image in; the annotated image plus the "<breed> (<confidence>%)" text out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil"),  # input image with the predicted label drawn on it
        "text",  # the same label as plain text
    ],
    title="Indian Bovine Breed Classifier",
    description=(
        "Upload an image of a cow and get the breed prediction with confidence."
    ),
)


def main():
    """Start the Gradio server hosting the breed classifier UI."""
    iface.launch()


if __name__ == "__main__":
    main()