Pie31415 commited on
Commit
cedb7e1
·
1 Parent(s): 5912c7b
Files changed (2) hide show
  1. app.py +1 -1
  2. attack.py +1 -1
app.py CHANGED
@@ -29,7 +29,7 @@ CLASSES = (
29
  def load_classifer(model_path):
30
  # load resnet model
31
  model = ResNet(ResidualBlock, [2, 2, 2])
32
- model.load_state_dict(torch.load(model_path))
33
  model.eval()
34
  return model
35
 
 
29
  def load_classifer(model_path):
30
  # load resnet model
31
  model = ResNet(ResidualBlock, [2, 2, 2])
32
+ model.load_state_dict(torch.load(model_path, map_location=device))
33
  model.eval()
34
  return model
35
 
attack.py CHANGED
@@ -3,7 +3,7 @@ from torchvision import transforms
3
 
4
 
5
  class Attack:
6
- def __init__(self, pipe, classifer, device="cuda"):
7
  self.device = device
8
  self.pipe = pipe
9
  self.generator = torch.Generator(device=self.device).manual_seed(1024)
 
3
 
4
 
5
  class Attack:
6
+ def __init__(self, pipe, classifer, device="cpu"):
7
  self.device = device
8
  self.pipe = pipe
9
  self.generator = torch.Generator(device=self.device).manual_seed(1024)