Fill-Mask
Transformers
Safetensors
esm
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os

class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        
    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize(), .encode(), and .decode() to be forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
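# Usage sketch for FusOnTokenizer (assumes the default facebook/esm2_t33_650M_UR50D
# checkpoint can be downloaded; the sequence is an arbitrary toy peptide). The wrapper
# behaves like the AutoTokenizer it wraps, so it can be used as a drop-in replacement:
#
#     tokenizer = FusOnTokenizer()                 # defaults to facebook/esm2_t33_650M_UR50D
#     batch = tokenizer(["MKTAYIAKQR"], return_tensors="pt", padding=True)  # __call__ delegation
#     tokens = tokenizer.tokenize("MKTAYIAKQR")    # other methods are forwarded via __getattr__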

class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM (or AutoModel when a checkpoint is loaded without its MLM head)
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path=None, mlm_head=False):
        if ckpt_path is not None:
            self.load_model(ckpt_path, mlm_head)
        else:
            # Load the pre-trained model and tokenizer
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)
            
        self.n_layers = self.count_encoder_layers()
    
    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
        """
        return getattr(self.model, name)
    
    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's __call__ method.
        """
        return self.model(*args, **kwargs)
    
    def freeze_model(self):
        """
        Freezes all parameters in the model
        """
        for param in self.model.parameters(): 
            param.requires_grad = False
    
    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        # The MLM model nests the ESM encoder under .esm; a bare AutoModel exposes it directly.
        base_model = getattr(self.model, "esm", self.model)
        for i, layer in enumerate(base_model.encoder.layer):
            if (self.n_layers - i) <= n_unfrozen_layers:  # Only the last n layers
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)

    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.
        
        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True

    
    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
        """
        # The MLM model nests the ESM encoder under .esm; a bare AutoModel exposes it directly.
        base_model = getattr(self.model, "esm", self.model)
        return len(base_model.encoder.layer)

    def save_model(self, save_directory, optimizer=None):
        # Save the model and tokenizer
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        
        # If an optimizer is provided, save its state dict
        if optimizer is not None:
            optimizer_path = os.path.join(save_directory, "optimizer.pt")
            torch.save(optimizer.state_dict(), optimizer_path)

    def load_model(self, load_directory, mlm_head):
        # Load a checkpoint of the model either with or without an MLM head
        if mlm_head:
            self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
        else:
            self.model = AutoModel.from_pretrained(load_directory)
        # Load the matching tokenizer from the same directory, wrapped for consistency with __init__
        self.tokenizer = FusOnTokenizer(load_directory)
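

if __name__ == "__main__":
    # Minimal usage sketch: assumes network access to download the default
    # facebook/esm2_t33_650M_UR50D checkpoint; the sequence and output directory
    # below are arbitrary placeholders. It walks through the intended workflow:
    # build the wrapper, freeze the backbone, unfreeze Q/K/V in the last few
    # encoder layers, run a forward pass, and save the checkpoint.
    model = FusOnpLM()                   # AutoModelForMaskedLM + FusOnTokenizer
    model.freeze_model()                 # freeze every parameter ...
    model.unfreeze_last_n_layers(8)      # ... then unfreeze Q/K/V in the last 8 layers

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{model.n_layers} encoder layers, {trainable:,} trainable parameters")

    # __call__ on the wrapper delegates to the underlying model's forward pass.
    inputs = model.tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    print(outputs.hidden_states[-1].shape)  # (1, sequence length + 2 special tokens, hidden size)

    model.save_model("fuson_plm_demo")   # placeholder output directory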