白鹭先生 commited on
Commit
4a9bffc
·
1 Parent(s): c1f6ea5

添加模型

Browse files
frame_field_learning/inference.py CHANGED
@@ -5,7 +5,7 @@ import scipy
5
 
6
  import numpy as np
7
  import torch
8
-
9
  from . import local_utils
10
  from . import polygonize
11
 
@@ -145,20 +145,20 @@ def load_checkpoint(model, checkpoints_dirpath, device):
145
  """
146
  Loads best val checkpoint in checkpoints_dirpath
147
  """
148
- filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
149
- endswith_str=".tar")
150
- if len(filepaths):
151
- filepaths = sorted(filepaths)
152
- filepath = filepaths[-1] # Last best val checkpoint filepath in case there is more than one
153
- print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
154
- else:
155
- # No best val checkpoint fount: find last checkpoint:
156
- filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
157
- startswith_str="checkpoint.")
158
- filepaths = sorted(filepaths)
159
- filepath = filepaths[-1] # Last checkpoint
160
- print_utils.print_info("Loading last checkpoint: {}".format(filepath))
161
-
162
  device = torch.device(device)
163
  checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
164
 
 
5
 
6
  import numpy as np
7
  import torch
8
+ from huggingface_hub import hf_hub_download
9
  from . import local_utils
10
  from . import polygonize
11
 
 
145
  """
146
  Loads best val checkpoint in checkpoints_dirpath
147
  """
148
+ # filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
149
+ # endswith_str=".tar")
150
+ # if len(filepaths):
151
+ # filepaths = sorted(filepaths)
152
+ # filepath = filepaths[-1] # Last best val checkpoint filepath in case there is more than one
153
+ # print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
154
+ # else:
155
+ # # No best val checkpoint fount: find last checkpoint:
156
+ # filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
157
+ # startswith_str="checkpoint.")
158
+ # filepaths = sorted(filepaths)
159
+ # filepath = filepaths[-1] # Last checkpoint
160
+ # print_utils.print_info("Loading last checkpoint: {}".format(filepath))
161
+ filepath = hf_hub_download(repo_id="Egrt/Luuuu", filename="checkpoint.best_val.epoch_000047.tar")
162
  device = torch.device(device)
163
  checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
164
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  cython
 
2
  scipy==1.4.1
3
  numpy==1.22.3
4
  matplotlib==3.3.2
 
1
  cython
2
+ huggingface_hub
3
  scipy==1.4.1
4
  numpy==1.22.3
5
  matplotlib==3.3.2