Shilpaj commited on
Commit
ebf74f1
·
verified ·
1 Parent(s): cbb6ccb

Fix: Runtime error for CPU

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -20,14 +20,15 @@ def load_model(model_path: str):
20
  # Load the pre-trained ResNet50 model from ImageNet
21
  model = models.resnet50(pretrained=False)
22
 
23
- # Load custom weights from a .pth file
24
- state_dict = torch.load(model_path)
25
 
26
  # Filter out unexpected keys
27
  filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}
28
 
29
  # Load the filtered state dictionary into the model
30
- model.load_state_dict(filtered_state_dict, strict=False)
 
31
  return model
32
 
33
 
 
20
  # Load the pre-trained ResNet50 model from ImageNet
21
  model = models.resnet50(pretrained=False)
22
 
23
+ # Load custom weights from a .pth file with CPU mapping
24
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
25
 
26
  # Filter out unexpected keys
27
  filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}
28
 
29
  # Load the filtered state dictionary into the model
30
+ model.load_state_dict(filtered_state_dict, strict=False)
31
+ model.eval()
32
  return model
33
 
34