|
from .base import BaseAligner |
|
|
|
|
|
def get_aligner(aligner_cfg):
    """Build an aligner from a config object.

    The concrete aligner class is chosen by ``aligner_cfg.name``. The
    instance is built with that class's ``from_config``. It may then load
    weights and/or have its parameters frozen.

    Args:
        aligner_cfg: Config object exposing the attributes ``name``,
            ``start_from`` and ``freeze``. Supported ``name`` values are
            ``'none'``, ``'retinaface_aligner'`` and
            ``'differentiable_face_aligner'``.

    Returns:
        The constructed aligner instance.

    Raises:
        ValueError: If ``aligner_cfg.name`` is not a supported aligner name.
    """
    name = aligner_cfg.name

    # Imports are deferred into each branch so that only the selected
    # backend (and its dependencies) is imported.
    if name == 'none':
        from .none import NoneAligner
        aligner = NoneAligner.from_config(aligner_cfg)
    elif name == 'retinaface_aligner':
        from .retinaface_aligner import RetinaFaceAligner
        aligner = RetinaFaceAligner.from_config(aligner_cfg)
    elif name == 'differentiable_face_aligner':
        from .differentiable_face_aligner import DifferentiableFaceAligner
        aligner = DifferentiableFaceAligner.from_config(aligner_cfg)
    else:
        # Previously reported "Unknown classifier", a copy-paste slip.
        raise ValueError(f"Unknown aligner: {name}")

    # A truthy ``start_from`` is passed to the aligner as a weights path.
    if aligner_cfg.start_from:
        aligner.load_state_dict_from_path(aligner_cfg.start_from)

    # Freezing disables gradients on every parameter.
    # NOTE(review): assumes the aligner is torch.nn.Module-like, i.e. that
    # ``parameters()`` yields objects with ``requires_grad``. Confirm this
    # against BaseAligner.
    if aligner_cfg.freeze:
        for param in aligner.parameters():
            param.requires_grad = False

    return aligner
|
|
|
|