Spaces:
Running
on
Zero
Running
on
Zero
bug-fix:load spiga ckpt from local
Browse files- checkpoints/spiga_300wpublic.pt +3 -0
- inference_utils.py +0 -11
- spiga_draw.py +10 -1
checkpoints/spiga_300wpublic.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98f014611ac25d549e89083992d9e6ade15da133c634a3883473abf2953cee2d
|
3 |
+
size 254397265
|
inference_utils.py
CHANGED
@@ -10,17 +10,6 @@ torch.cuda.manual_seed_all(seed)
|
|
10 |
torch.backends.cudnn.deterministic = True
|
11 |
torch.backends.cudnn.benchmark = False
|
12 |
|
13 |
-
# SPIGA ckpt downloading always fails, so we download it manually and put it in the right place.
|
14 |
-
import spiga
|
15 |
-
from gdown import download
|
16 |
-
|
17 |
-
pkg_path = spiga.__file__.replace("/__init__.py", "")
|
18 |
-
spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
|
19 |
-
ckpt_path = os.path.join(pkg_path, "spiga/models/weights/spiga_300wpublic.pt")
|
20 |
-
if not os.path.exists(ckpt_path):
|
21 |
-
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
|
22 |
-
download(id=spiga_file_id, output=ckpt_path)
|
23 |
-
|
24 |
from PIL import Image
|
25 |
from gdown import download_folder
|
26 |
from facelib import FaceDetector
|
|
|
10 |
torch.backends.cudnn.deterministic = True
|
11 |
torch.backends.cudnn.benchmark = False
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from PIL import Image
|
14 |
from gdown import download_folder
|
15 |
from facelib import FaceDetector
|
spiga_draw.py
CHANGED
@@ -7,7 +7,16 @@ from facelib import FaceDetector
|
|
7 |
from spiga.inference.config import ModelConfig
|
8 |
from spiga.inference.framework import SPIGAFramework
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def center_crop(image, size):
|
13 |
width, height = image.size
|
|
|
7 |
from spiga.inference.config import ModelConfig
|
8 |
from spiga.inference.framework import SPIGAFramework
|
9 |
|
10 |
+
# SPIGA ckpt downloading always fails, so we load it from the local path instead.
|
11 |
+
spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
|
12 |
+
if not os.path.exists(spiga_ckpt):
|
13 |
+
from gdown import download
|
14 |
+
spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
|
15 |
+
download(id=spiga_file_id, output=spiga_ckpt)
|
16 |
+
spiga_config = ModelConfig("300wpublic")
|
17 |
+
spiga_config.load_model_url = False
|
18 |
+
spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
|
19 |
+
processor = SPIGAFramework(spiga_config)
|
20 |
|
21 |
def center_crop(image, size):
|
22 |
width, height = image.size
|