Spaces:
Runtime error
Runtime error
Set map_location in torch.load call
Browse files- shakespeare_demo.py +2 -1
shakespeare_demo.py
CHANGED
|
@@ -22,7 +22,8 @@ with open('config.yaml', 'r') as f:
|
|
| 22 |
#%%
|
| 23 |
with open('model_state_dict.pt') as f:
|
| 24 |
state_dict = t.load(
|
| 25 |
-
'model_state_dict.pt'
|
|
|
|
| 26 |
)
|
| 27 |
#%%
|
| 28 |
base_config = transformer_replication.TransformerConfig(
|
|
|
|
| 22 |
#%%
|
| 23 |
with open('model_state_dict.pt') as f:
|
| 24 |
state_dict = t.load(
|
| 25 |
+
'model_state_dict.pt',
|
| 26 |
+
map_location=device,
|
| 27 |
)
|
| 28 |
#%%
|
| 29 |
base_config = transformer_replication.TransformerConfig(
|