minchul's picture
Upload directory
3ee5516 verified
raw
history blame
1.95 kB
import os
from typing import Union
import torch
from torch import device
from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
class BaseAligner(torch.nn.Module):
    """Base class for aligner modules.

    Subclasses must override ``from_config``, ``make_train_transform``,
    ``make_test_transform`` and ``forward``; each raises
    ``NotImplementedError`` here. The remaining methods (saving/loading
    weights, device/dtype queries, parameter counting) are shared helpers.

    Args:
        config: Configuration object stored as ``self.config`` and passed
            to ``save_state_dict_and_config`` by ``save_pretrained``.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config) -> "BaseAligner":
        """Construct an aligner from ``config``. Must be overridden."""
        raise NotImplementedError('from_config must be implemented in subclass')

    def make_train_transform(self):
        """Return the transform used at training time. Must be overridden."""
        raise NotImplementedError('make_train_transform must be implemented in subclass')

    def make_test_transform(self):
        """Return the transform used at test time. Must be overridden."""
        raise NotImplementedError('make_test_transform must be implemented in subclass')

    def forward(self, x):
        """Run the aligner on ``x``. Must be overridden."""
        raise NotImplementedError('forward must be implemented in subclass')

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """Save this module's state dict and config to ``save_dir/name``.

        Only the call with ``rank == 0`` writes anything; other ranks are a
        no-op (presumably so that only one process writes in distributed
        training — confirm with callers).

        Args:
            save_dir: Directory to write into.
            name: File name inside ``save_dir``.
            rank: Caller's rank; the file is written only when this is 0.
        """
        if rank != 0:
            return
        save_path = os.path.join(save_dir, name)
        # NOTE(review): on-disk format and whether save_dir is created are
        # decided by save_state_dict_and_config (.utils) — not visible here.
        save_state_dict_and_config(self.state_dict(), self.config, save_path)

    def load_state_dict_from_path(self, pretrained_model_path, strict: bool = True):
        """Load weights stored at ``pretrained_model_path`` into this module.

        Args:
            pretrained_model_path: Path passed to the module-level
                ``load_state_dict_from_path`` helper from ``.utils``.
            strict: Forwarded to ``torch.nn.Module.load_state_dict``.

        Returns:
            The result of ``torch.nn.Module.load_state_dict`` (a named tuple
            with ``missing_keys`` and ``unexpected_keys``).
        """
        # Inside a method body this name resolves to the module-level helper
        # imported from .utils, not to this method.
        state_dict = load_state_dict_from_path(pretrained_model_path)
        result = self.load_state_dict(state_dict, strict=strict)
        print(f"Loaded pretrained aligner from {pretrained_model_path}")
        return result

    @property
    def device(self) -> device:
        """Device reported by ``get_parameter_device`` for this module."""
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype reported by ``get_parameter_dtype`` for this module."""
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        """Return the total number of scalar parameter elements.

        Args:
            only_trainable: If True, count only parameters with
                ``requires_grad`` set.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def has_trainable_params(self) -> bool:
        """Return True if any parameter has ``requires_grad`` set."""
        return any(p.requires_grad for p in self.parameters())

    def has_params(self) -> bool:
        """Return True if the module owns at least one parameter."""
        # Stop at the first parameter instead of materialising the full list.
        return next(self.parameters(), None) is not None