Spaces:
Paused
Paused
yjwtheonly
commited on
Commit
·
f6678bd
1
Parent(s):
ac8e861
modification
Browse files- DiseaseSpecific/utils.py +2 -2
DiseaseSpecific/utils.py
CHANGED
|
@@ -71,13 +71,13 @@ def load_model(model_path, args, n_ent, n_rel, device):
|
|
| 71 |
model = add_model(args, n_ent, n_rel)
|
| 72 |
model.to(device)
|
| 73 |
logger.info('Loading saved model from {0}'.format(model_path))
|
| 74 |
-
state = torch.load(model_path)
|
| 75 |
model_params = state['state_dict']
|
| 76 |
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
| 77 |
for key, size, count in params:
|
| 78 |
logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
|
| 79 |
|
| 80 |
-
model.load_state_dict(model_params
|
| 81 |
model.eval()
|
| 82 |
logger.info(model)
|
| 83 |
|
|
|
|
| 71 |
model = add_model(args, n_ent, n_rel)
|
| 72 |
model.to(device)
|
| 73 |
logger.info('Loading saved model from {0}'.format(model_path))
|
| 74 |
+
state = torch.load(model_path, map_location=device)
|
| 75 |
model_params = state['state_dict']
|
| 76 |
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
| 77 |
for key, size, count in params:
|
| 78 |
logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
|
| 79 |
|
| 80 |
+
model.load_state_dict(model_params)
|
| 81 |
model.eval()
|
| 82 |
logger.info(model)
|
| 83 |
|