dhe1raj commited on
Commit
66d925a
·
verified ·
1 Parent(s): 126fb89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -7,18 +7,21 @@ import gradio as gr
7
  # CONFIGURATION
8
  # =======================
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth" # Upload your .pth model here
11
  CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
12
 
13
  # =======================
14
  # MODEL: EfficientNetB3
15
  # =======================
16
- model = models.efficientnet_b3(weights=None) # Do NOT load pretrained weights here
17
- # Update classifier for 3 classes
18
  model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
19
 
20
- # Load checkpoint safely (ignores classifier mismatch)
 
 
21
  checkpoint = torch.load(MODEL_PATH, map_location=device)
 
 
22
  model.load_state_dict(checkpoint, strict=False)
23
 
24
  model.to(device)
 
7
  # CONFIGURATION
8
  # =======================
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ MODEL_PATH = "best_model.pth" # Upload your .pth model here
11
  CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
12
 
13
  # =======================
14
  # MODEL: EfficientNetB3
15
  # =======================
16
+ model = models.efficientnet_b3(weights=None)
 
17
  model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
18
 
19
+ # =======================
20
+ # LOAD CHECKPOINT (Feature Extractor Only)
21
+ # =======================
22
  checkpoint = torch.load(MODEL_PATH, map_location=device)
23
+ # Remove classifier weights from checkpoint
24
+ checkpoint = {k: v for k, v in checkpoint.items() if "classifier" not in k}
25
  model.load_state_dict(checkpoint, strict=False)
26
 
27
  model.to(device)