sky24h commited on
Commit
53f2335
·
1 Parent(s): 31abe01

bug-fix:load spiga ckpt from local

Browse files
checkpoints/spiga_300wpublic.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98f014611ac25d549e89083992d9e6ade15da133c634a3883473abf2953cee2d
3
+ size 254397265
inference_utils.py CHANGED
@@ -10,17 +10,6 @@ torch.cuda.manual_seed_all(seed)
10
  torch.backends.cudnn.deterministic = True
11
  torch.backends.cudnn.benchmark = False
12
 
13
- # SPIGA ckpt downloading always fails, so we download it manually and put it in the right place.
14
- import spiga
15
- from gdown import download
16
-
17
- pkg_path = spiga.__file__.replace("/__init__.py", "")
18
- spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
19
- ckpt_path = os.path.join(pkg_path, "spiga/models/weights/spiga_300wpublic.pt")
20
- if not os.path.exists(ckpt_path):
21
- os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
22
- download(id=spiga_file_id, output=ckpt_path)
23
-
24
  from PIL import Image
25
  from gdown import download_folder
26
  from facelib import FaceDetector
 
10
  torch.backends.cudnn.deterministic = True
11
  torch.backends.cudnn.benchmark = False
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  from PIL import Image
14
  from gdown import download_folder
15
  from facelib import FaceDetector
spiga_draw.py CHANGED
@@ -7,7 +7,16 @@ from facelib import FaceDetector
7
  from spiga.inference.config import ModelConfig
8
  from spiga.inference.framework import SPIGAFramework
9
 
10
- processor = SPIGAFramework(ModelConfig("300wpublic"))
 
 
 
 
 
 
 
 
 
11
 
12
  def center_crop(image, size):
13
  width, height = image.size
 
7
  from spiga.inference.config import ModelConfig
8
  from spiga.inference.framework import SPIGAFramework
9
 
10
+ # SPIGA ckpt downloading always fails, so we load it from the local path instead.
11
+ spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
12
+ if not os.path.exists(spiga_ckpt):
13
+ from gdown import download
14
+ spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
15
+ download(id=spiga_file_id, output=spiga_ckpt)
16
+ spiga_config = ModelConfig("300wpublic")
17
+ spiga_config.load_model_url = False
18
+ spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
19
+ processor = SPIGAFramework(spiga_config)
20
 
21
  def center_crop(image, size):
22
  width, height = image.size