dhe1raj commited on
Commit
126fb89
·
verified ·
1 Parent(s): 9283415

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -14
app.py CHANGED
@@ -1,26 +1,31 @@
1
  import torch
2
  from torchvision import models, transforms
3
- from PIL import Image
4
  import gradio as gr
5
 
6
  # =======================
7
- # Configuration
8
  # =======================
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth" # Upload this to the Space
11
  CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
12
 
13
  # =======================
14
- # Load Model
15
  # =======================
16
- model = models.efficientnet_b3(pretrained=False)
 
17
  model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
18
- model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
 
 
 
 
19
  model.to(device)
20
  model.eval()
21
 
22
  # =======================
23
- # Image Preprocessing
24
  # =======================
25
  transform = transforms.Compose([
26
  transforms.Resize((300, 300)),
@@ -30,25 +35,38 @@ transform = transforms.Compose([
30
  ])
31
 
32
  # =======================
33
- # Prediction Function
34
  # =======================
35
  def predict(image):
36
  image = image.convert("RGB")
37
  img_tensor = transform(image).unsqueeze(0).to(device)
 
38
  with torch.no_grad():
39
  output = model(img_tensor)
40
- pred_idx = torch.argmax(output, dim=1).item()
41
- return CLASS_NAMES[pred_idx]
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # =======================
44
- # Gradio Interface
45
  # =======================
46
  iface = gr.Interface(
47
  fn=predict,
48
  inputs=gr.Image(type="pil"),
49
- outputs="text",
50
  title="Indian Bovine Breed Classifier",
51
- description="Upload an image of a cow and the model will predict its breed."
52
  )
53
 
54
- iface.launch()
 
 
1
  import torch
2
  from torchvision import models, transforms
3
+ from PIL import Image, ImageDraw, ImageFont
4
  import gradio as gr
5
 
6
  # =======================
7
+ # CONFIGURATION
8
  # =======================
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth" # Upload your .pth model here
11
  CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
12
 
13
  # =======================
14
+ # MODEL: EfficientNetB3
15
  # =======================
16
+ model = models.efficientnet_b3(weights=None) # Do NOT load pretrained weights here
17
+ # Update classifier for 3 classes
18
  model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
19
+
20
+ # Load checkpoint safely (ignores classifier mismatch)
21
+ checkpoint = torch.load(MODEL_PATH, map_location=device)
22
+ model.load_state_dict(checkpoint, strict=False)
23
+
24
  model.to(device)
25
  model.eval()
26
 
27
  # =======================
28
+ # IMAGE TRANSFORMS
29
  # =======================
30
  transform = transforms.Compose([
31
  transforms.Resize((300, 300)),
 
35
  ])
36
 
37
  # =======================
38
+ # PREDICTION FUNCTION
39
  # =======================
40
  def predict(image):
41
  image = image.convert("RGB")
42
  img_tensor = transform(image).unsqueeze(0).to(device)
43
+
44
  with torch.no_grad():
45
  output = model(img_tensor)
46
+ probs = torch.nn.functional.softmax(output, dim=1)
47
+ conf, pred_idx = torch.max(probs, dim=1)
48
+
49
+ pred_label = CLASS_NAMES[pred_idx.item()]
50
+ confidence = conf.item() * 100
51
+
52
+ # Draw label on image
53
+ draw = ImageDraw.Draw(image)
54
+ font = ImageFont.load_default()
55
+ text = f"{pred_label} ({confidence:.2f}%)"
56
+ draw.text((10, 10), text, fill="red", font=font)
57
+
58
+ return image, text
59
 
60
  # =======================
61
+ # GRADIO INTERFACE
62
  # =======================
63
  iface = gr.Interface(
64
  fn=predict,
65
  inputs=gr.Image(type="pil"),
66
+ outputs=[gr.Image(type="pil"), "text"],
67
  title="Indian Bovine Breed Classifier",
68
+ description="Upload an image of a cow and get the breed prediction with confidence."
69
  )
70
 
71
+ if __name__ == "__main__":
72
+ iface.launch()