Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -29,7 +29,7 @@ CLASSES = (
|
|
29 |
def load_classifer(model_path):
|
30 |
# load resnet model
|
31 |
model = ResNet(ResidualBlock, [2, 2, 2])
|
32 |
-
model.load_state_dict(torch.load(model_path))
|
33 |
model.eval()
|
34 |
return model
|
35 |
|
|
|
29 |
def load_classifer(model_path):
|
30 |
# load resnet model
|
31 |
model = ResNet(ResidualBlock, [2, 2, 2])
|
32 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
33 |
model.eval()
|
34 |
return model
|
35 |
|
attack.py
CHANGED
@@ -3,7 +3,7 @@ from torchvision import transforms
|
|
3 |
|
4 |
|
5 |
class Attack:
|
6 |
-
def __init__(self, pipe, classifer, device="
|
7 |
self.device = device
|
8 |
self.pipe = pipe
|
9 |
self.generator = torch.Generator(device=self.device).manual_seed(1024)
|
|
|
3 |
|
4 |
|
5 |
class Attack:
|
6 |
+
def __init__(self, pipe, classifer, device="cpu"):
|
7 |
self.device = device
|
8 |
self.pipe = pipe
|
9 |
self.generator = torch.Generator(device=self.device).manual_seed(1024)
|