jhj0517
commited on
Commit
·
7962f8d
1
Parent(s):
7a623dc
Refactor
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -55,6 +55,11 @@ class LivePortraitInferencer:
|
|
| 55 |
self.d_info = None
|
| 56 |
|
| 57 |
def load_models(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
self.download_if_no_models()
|
| 59 |
|
| 60 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
|
@@ -85,11 +90,6 @@ class LivePortraitInferencer:
|
|
| 85 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
| 86 |
)
|
| 87 |
|
| 88 |
-
def filter_stitcher(checkpoint, prefix):
|
| 89 |
-
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 90 |
-
key.startswith(prefix)}
|
| 91 |
-
return filtered_checkpoint
|
| 92 |
-
|
| 93 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 94 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 95 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|
|
|
|
| 55 |
self.d_info = None
|
| 56 |
|
| 57 |
def load_models(self):
|
| 58 |
+
def filter_stitcher(checkpoint, prefix):
|
| 59 |
+
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 60 |
+
key.startswith(prefix)}
|
| 61 |
+
return filtered_checkpoint
|
| 62 |
+
|
| 63 |
self.download_if_no_models()
|
| 64 |
|
| 65 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
|
|
|
| 90 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
| 91 |
)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 94 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 95 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|