Update model.py
Browse files
model.py
CHANGED
@@ -52,6 +52,7 @@ class ViTRecognitionModel(nn.Module):
|
|
52 |
|
53 |
def load_model(model_path, device='cpu'):
|
54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
|
|
55 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
56 |
model.to(device)
|
57 |
model.eval()
|
|
|
52 |
|
53 |
def load_model(model_path, device='cpu'):
|
54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
55 |
+
# Set weights_only=True to address the FutureWarning
|
56 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
57 |
model.to(device)
|
58 |
model.eval()
|