白鹭先生
commited on
Commit
·
4a9bffc
1
Parent(s):
c1f6ea5
添加模型
Browse files- frame_field_learning/inference.py +15 -15
- requirements.txt +1 -0
frame_field_learning/inference.py
CHANGED
@@ -5,7 +5,7 @@ import scipy
|
|
5 |
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
-
|
9 |
from . import local_utils
|
10 |
from . import polygonize
|
11 |
|
@@ -145,20 +145,20 @@ def load_checkpoint(model, checkpoints_dirpath, device):
|
|
145 |
"""
|
146 |
Loads best val checkpoint in checkpoints_dirpath
|
147 |
"""
|
148 |
-
filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
|
149 |
-
|
150 |
-
if len(filepaths):
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
else:
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
device = torch.device(device)
|
163 |
checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
|
164 |
|
|
|
5 |
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
from . import local_utils
|
10 |
from . import polygonize
|
11 |
|
|
|
145 |
"""
|
146 |
Loads best val checkpoint in checkpoints_dirpath
|
147 |
"""
|
148 |
+
# filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
|
149 |
+
# endswith_str=".tar")
|
150 |
+
# if len(filepaths):
|
151 |
+
# filepaths = sorted(filepaths)
|
152 |
+
# filepath = filepaths[-1] # Last best val checkpoint filepath in case there is more than one
|
153 |
+
# print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
|
154 |
+
# else:
|
155 |
+
# # No best val checkpoint fount: find last checkpoint:
|
156 |
+
# filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
|
157 |
+
# startswith_str="checkpoint.")
|
158 |
+
# filepaths = sorted(filepaths)
|
159 |
+
# filepath = filepaths[-1] # Last checkpoint
|
160 |
+
# print_utils.print_info("Loading last checkpoint: {}".format(filepath))
|
161 |
+
filepath = hf_hub_download(repo_id="Egrt/Luuuu", filename="checkpoint.best_val.epoch_000047.tar")
|
162 |
device = torch.device(device)
|
163 |
checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
|
164 |
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
cython
|
|
|
2 |
scipy==1.4.1
|
3 |
numpy==1.22.3
|
4 |
matplotlib==3.3.2
|
|
|
1 |
cython
|
2 |
+
huggingface_hub
|
3 |
scipy==1.4.1
|
4 |
numpy==1.22.3
|
5 |
matplotlib==3.3.2
|