Gizachew commited on
Commit
1b9e776
·
verified ·
1 Parent(s): 52540c8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -52,7 +52,7 @@ class ViTRecognitionModel(nn.Module):
52
 
53
  def load_model(model_path, device='cpu'):
54
  model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
55
- model.load_state_dict(torch.load(model_path, map_location=device))
56
  model.to(device)
57
  model.eval()
58
  return model
 
52
 
53
  def load_model(model_path, device='cpu'):
54
  model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
55
+ model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
56
  model.to(device)
57
  model.eval()
58
  return model