import os
from typing import Union

import torch
from torch import device

from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path


class BaseAligner(torch.nn.Module):
    """Base class for aligner models.

    Subclasses must implement `from_config`, `make_train_transform`,
    `make_test_transform`, and `forward`.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config) -> "BaseAligner":
        raise NotImplementedError('from_config must be implemented in subclass')

    def make_train_transform(self):
        raise NotImplementedError('make_train_transform must be implemented in subclass')

    def make_test_transform(self):
        raise NotImplementedError('make_test_transform must be implemented in subclass')

    def forward(self, x):
        raise NotImplementedError('forward must be implemented in subclass')

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """Save the state dict and config to `save_dir/name`; only rank 0 writes."""
        save_path = os.path.join(save_dir, name)
        if rank == 0:
            save_state_dict_and_config(self.state_dict(), self.config, save_path)

    def load_state_dict_from_path(self, pretrained_model_path):
        # The bare name below resolves to the module-level helper imported from
        # .utils, which this method deliberately shadows as an instance API.
        state_dict = load_state_dict_from_path(pretrained_model_path)
        result = self.load_state_dict(state_dict)
        print(f"Loaded pretrained aligner from {pretrained_model_path}: {result}")

    @property
    def device(self) -> device:
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        return sum(p.numel() for p in self.parameters() if not only_trainable or p.requires_grad)

    def has_trainable_params(self):
        return any(p.requires_grad for p in self.parameters())

    def has_params(self):
        return len(list(self.parameters())) > 0
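

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library API: a minimal concrete subclass
# and a save/load round trip. `_IdentityAligner`, its dummy `scale` parameter,
# and the `tmp_aligner` scratch directory are assumptions for demonstration only.
if __name__ == "__main__":
    class _IdentityAligner(BaseAligner):
        def __init__(self, config=None):
            super().__init__(config=config)
            # One dummy parameter so device/dtype/parameter counts are non-trivial.
            self.scale = torch.nn.Parameter(torch.ones(1))

        @classmethod
        def from_config(cls, config) -> "BaseAligner":
            return cls(config=config)

        def make_train_transform(self):
            return lambda x: x

        def make_test_transform(self):
            return lambda x: x

        def forward(self, x):
            return self.scale * x

    aligner = _IdentityAligner.from_config({"name": "identity"})
    print(aligner.device, aligner.dtype, aligner.num_parameters(only_trainable=True))

    # Round trip through the save/load helpers.
    os.makedirs("tmp_aligner", exist_ok=True)
    aligner.save_pretrained("tmp_aligner", name="model.pt")
    aligner.load_state_dict_from_path(os.path.join("tmp_aligner", "model.pt"))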