ndhieunguyen commited on
Commit
f01494b
·
1 Parent(s): 77180e4

fix: load model

Browse files
Files changed (1) hide show
  1. app.py +1 -4
app.py CHANGED
@@ -42,10 +42,7 @@ def get_model():
42
  num_hidden_layers=12,
43
  )
44
  model.load_state_dict(
45
- dist_util.load_state_dict(
46
- os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
47
- map_location="cpu",
48
- )
49
  )
50
  model.eval()
51
  return model
 
42
  num_hidden_layers=12,
43
  )
44
  model.load_state_dict(
45
+ torch.load(os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"))
 
 
 
46
  )
47
  model.eval()
48
  return model