Spaces:
Runtime error
Runtime error
birdortyedi
commited on
Commit
·
2ee4de9
1
Parent(s):
74de975
hf hub added
Browse files
app.py
CHANGED
|
@@ -18,14 +18,15 @@ cfg.MODEL.CKPT = model_path
|
|
| 18 |
net, _ = build_model(cfg)
|
| 19 |
net = net.eval()
|
| 20 |
vgg16 = models.vgg16(pretrained=True).features.eval()
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
-
def load_checkpoints_from_ckpt(ckpt_path):
|
| 24 |
-
checkpoints = torch.load(ckpt_path, map_location=
|
| 25 |
net.load_state_dict(checkpoints["ifr"])
|
| 26 |
|
| 27 |
|
| 28 |
-
load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
|
| 29 |
|
| 30 |
|
| 31 |
def filter_removal(img):
|
|
|
|
| 18 |
net, _ = build_model(cfg)
|
| 19 |
net = net.eval()
|
| 20 |
vgg16 = models.vgg16(pretrained=True).features.eval()
|
| 21 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 22 |
|
| 23 |
|
| 24 |
+
def load_checkpoints_from_ckpt(ckpt_path, device):
|
| 25 |
+
checkpoints = torch.load(ckpt_path, map_location=device)
|
| 26 |
net.load_state_dict(checkpoints["ifr"])
|
| 27 |
|
| 28 |
|
| 29 |
+
load_checkpoints_from_ckpt(cfg.MODEL.CKPT, device)
|
| 30 |
|
| 31 |
|
| 32 |
def filter_removal(img):
|