ndhieunguyen commited on
Commit
688df3c
·
1 Parent(s): f01494b

feat: map location

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -42,7 +42,10 @@ def get_model():
42
  num_hidden_layers=12,
43
  )
44
  model.load_state_dict(
45
- torch.load(os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"))
 
 
 
46
  )
47
  model.eval()
48
  return model
 
42
  num_hidden_layers=12,
43
  )
44
  model.load_state_dict(
45
+ torch.load(
46
+ os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
47
+ map_location=torch.device("cpu"),
48
+ )
49
  )
50
  model.eval()
51
  return model