minchul committed
Commit 3ee5516 · verified · 1 Parent(s): 977ddbe

Upload directory

Files changed (1)
  1. aligners/base/__init__.py +60 -0
aligners/base/__init__.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ from typing import Union
+ import torch
+ from torch import device
+ from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
+
+ class BaseAligner(torch.nn.Module):
+     """Common interface for aligner modules; subclasses supply the transforms and the forward pass."""
+
+     def __init__(self, config=None):
+         super().__init__()
+         self.config = config
+
+     @classmethod
+     def from_config(cls, config) -> "BaseAligner":
+         raise NotImplementedError('from_config must be implemented in subclass')
+
+     def make_train_transform(self):
+         raise NotImplementedError('make_train_transform must be implemented in subclass')
+
+     def make_test_transform(self):
+         raise NotImplementedError('make_test_transform must be implemented in subclass')
+
+     def forward(self, x):
+         raise NotImplementedError('forward must be implemented in subclass')
+
+     def save_pretrained(
+         self,
+         save_dir: Union[str, os.PathLike],
+         name: str = 'model.pt',
+         rank: int = 0,
+     ):
+         # Only rank 0 writes the checkpoint so distributed training does not produce duplicate files.
+         save_path = os.path.join(save_dir, name)
+         if rank == 0:
+             save_state_dict_and_config(self.state_dict(), self.config, save_path)
+
+     def load_state_dict_from_path(self, pretrained_model_path):
+         state_dict = load_state_dict_from_path(pretrained_model_path)
+         result = self.load_state_dict(state_dict)
+         print(f"Loaded pretrained aligner from {pretrained_model_path}")
+         return result
+
+     @property
+     def device(self) -> device:
+         return get_parameter_device(self)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return get_parameter_dtype(self)
+
+     def num_parameters(self, only_trainable: bool = False) -> int:
+         # Count every parameter, or only trainable ones when only_trainable is True.
+         return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+     def has_trainable_params(self):
+         for param in self.parameters():
+             if param.requires_grad:
+                 return True
+         return False
+
+     def has_params(self):
+         return len(list(self.parameters())) > 0
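
For reference, a concrete aligner is expected to override from_config, make_train_transform, make_test_transform, and forward. The sketch below shows what a minimal subclass could look like; IdentityAligner, its output_size field, and the torchvision transforms are illustrative assumptions and are not part of this commit (it also assumes the repository root is on the Python path so that aligners.base is importable).

# Hypothetical subclass sketch -- IdentityAligner and its config handling are illustrative only.
import torch
from torchvision import transforms

from aligners.base import BaseAligner


class IdentityAligner(BaseAligner):
    """Toy aligner that resizes inputs in the transforms and passes tensors through unchanged."""

    def __init__(self, config=None):
        super().__init__(config)
        # output_size is an assumed config field; fall back to 112 when absent.
        self.output_size = getattr(config, 'output_size', 112) if config is not None else 112

    @classmethod
    def from_config(cls, config) -> "IdentityAligner":
        return cls(config)

    def make_train_transform(self):
        return transforms.Compose([
            transforms.Resize((self.output_size, self.output_size)),
            transforms.ToTensor(),
        ])

    def make_test_transform(self):
        # Same pipeline at test time for this toy example.
        return self.make_train_transform()

    def forward(self, x):
        # No alignment is performed; the input batch is returned as-is.
        return x


if __name__ == '__main__':
    aligner = IdentityAligner.from_config(None)
    dummy = torch.randn(2, 3, 112, 112)
    out = aligner(dummy)
    # This toy subclass registers no parameters, so num_parameters() is 0 and has_params() is False.
    print(out.shape, aligner.num_parameters(), aligner.has_params())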