# Python file for making embeddings from a FusOn-pLM model (or baseline ESM-2 / ProtT5 models) for any dataset
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.utils.data_cleaning import find_invalid_chars
from fuson_plm.utils.constants import VALID_AAS
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import pickle
import os
import pandas as pd
import numpy as np

def validate_sequence_col(df, seq_col):
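    """Raise if seq_col is missing from df or contains characters outside VALID_AAS; warn on duplicate sequences."""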
    # if column isn't there, error
    if seq_col not in list(df.columns):
        raise Exception("Error: provided sequence column does not exist in the input dataframe")
    
    # if column contains invalid characters, error
    # (compute on a temporary Series so the caller's dataframe is not mutated)
    invalid_chars = df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*invalid_chars)
    if len(all_invalid_chars) > 0:
        raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")
    
    # make sure there are no duplicates
    sequences = df[seq_col]
    if len(set(sequences))<len(sequences): log_update("\tWARNING: input data has duplicate sequences")
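
# A minimal usage sketch for validate_sequence_col (the dataframe below is a
# hypothetical example, not part of the pipeline):
#   df = pd.DataFrame({'aa_seq': ['MKTAYIAK', 'GAVLILLE']})
#   validate_sequence_col(df, 'aa_seq')   # passes silently
#   validate_sequence_col(df, 'seq')      # raises: column does not exist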

def load_fuson_model(ckpt_path):
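    """Load a FusOn-pLM checkpoint and tokenizer from ckpt_path; returns (model, tokenizer, device) with the model in eval mode."""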
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    model = AutoModel.from_pretrained(ckpt_path)              # initialize model
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)      # initialize tokenizer

    # Model to device and in eval mode
    model.to(device)
    model.eval()  # disables dropout for deterministic results
    
    return model, tokenizer, device
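
# Example usage (the checkpoint path below is hypothetical; any local
# HuggingFace-format FusOn-pLM checkpoint directory should work):
#   model, tokenizer, device = load_fuson_model('../../training/checkpoints/best/checkpoint_epoch_1')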

def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000): 
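    """
    Embed each sequence with FusOn-pLM and pickle the results.

    Each sequence is tokenized, passed through the model, and the BOS/EOS token
    embeddings are stripped from the last hidden state. If average=True, the
    per-residue embeddings are mean-pooled into one vector per sequence. Results
    are collected as {sequence: embedding}; if savepath is given, embeddings are
    appended to the pickle one at a time and repacked at the end, or dumped once
    at the end if save_at_end=True.
    """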
    # Ensure the save path ends in .pkl
    if savepath is not None and not savepath.endswith('.pkl'):
        savepath += '.pkl'
    
    if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")
    
    # If no max length was passed, set it to the longest sequence in the dataset (+2 for BOS, EOS tokens)
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2
    
    # Initialize an empty dict to store the embeddings
    embedding_dict = {}
    # Iterate through the sequences
    for i, sequence in enumerate(sequences):
        # Get the embeddings
        with torch.no_grad():
            # Tokenize the input sequence
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True,max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}
        
            outputs = model(**inputs)
            # The embeddings are in the last_hidden_state tensor
            embedding = outputs.last_hidden_state
            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Move the embedding to CPU and convert to a numpy array
            embedding = embedding.cpu().numpy()

        # Average across residues (if requested)
        if average:
            embedding = embedding.mean(0)
        
        # Add to dictionary
        embedding_dict[sequence] = embedding
            
        # Save individual embedding incrementally (if requested)
        if savepath is not None and not save_at_end:
            with open(savepath, 'ab+') as f:
                pickle.dump({sequence: embedding}, f)

        # Print update (if necessary)
        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
    
    # Dump all at once at the end (if necessary)
    if savepath is not None:
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving incrementally and made it here without crashing, repack the pickle into a single dict so it loads cleanly
        else:
            redump_pickle_dictionary(savepath)
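
# Minimal sketch of embedding a few sequences (the sequences and output path
# below are illustrative placeholders):
#   seqs = ['MKTAYIAK', 'GAVLILLE']
#   get_fuson_embeddings(model, tokenizer, seqs, device, average=True,
#                        savepath='demo_embeddings.pkl', save_at_end=True)
#   with open('demo_embeddings.pkl', 'rb') as f:
#       emb_dict = pickle.load(f)   # {sequence: mean-pooled 1D embedding}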

def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path = None, average=True, overwrite=True, print_updates=False,max_length=2000):
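    """
    Embed every unique sequence in a CSV with the requested model and pickle the results.

    Args:
        path_to_file (str): path to the input CSV containing the sequences.
        path_to_output (str): path for the output .pkl of {sequence: embedding}.
        seq_col (str): name of the sequence column in the CSV.
        model_type (str): 'fuson_plm', 'esm2_t33_650M_UR50D', or 'prot_t5_xl_half_uniref50_enc'.
        fuson_ckpt_path (str): checkpoint path; required when model_type='fuson_plm'.
        average (bool): if True, mean-pool per-residue embeddings into one vector per sequence.
        overwrite (bool): if False, skip datasets whose output file already exists.
        print_updates (bool): if True, log progress for each sequence.
        max_length (int): maximum tokenized sequence length.
    """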
    # Check whether embeddings already exist at the output path
    if os.path.exists(path_to_output):
        if overwrite:
            log_update(f"WARNING: embeddings already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: embeddings already exist at {path_to_output}. Skipping.")
            return None
    
    dataset = pd.read_csv(path_to_file)
    # Make sure the sequence column is valid
    validate_sequence_col(dataset, seq_col) 
    
    sequences = dataset[seq_col].unique().tolist() # ensure all entries are unique

    ### If FusOn-pLM: make fusion embeddings
    if model_type=='fuson_plm':
        if fuson_ckpt_path is None or not os.path.exists(fuson_ckpt_path):
            raise Exception("FusOn-pLM ckpt path does not exist")
        
        # Load model
        try:
            model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
        except Exception as e:
            raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}") from e
        
        # Generate embeddings
        try:
            get_fuson_embeddings(model, tokenizer, sequences, device, average=average, 
                                 print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                 max_length=max_length)
        except Exception as e:
            raise Exception("Could not generate FusOn-pLM embeddings") from e
    
    elif model_type=='esm2_t33_650M_UR50D':
        # Load model
        try:
            model, tokenizer, device = load_esm2_type(model_type)
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
        # Generate embeddings
        try:
            get_esm_embeddings(model, tokenizer, sequences, device, average=average, 
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e
    
    if model_type=="prot_t5_xl_half_uniref50_enc":
        # Load model
        try:
            model, tokenizer, device = load_prott5()
        except:
            raise Exception(f"Could not load {model_type}")
        # Generate embeddings
        try:
            get_prott5_embeddings(model, tokenizer, sequences, device, average=average, 
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except:
            raise Exception(f"Could not generate {model_type} embeddings")
        
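# Example call (paths and column name below are hypothetical):
#   embed_dataset('splits/test_df.csv', 'embeddings/test_df_average_embeddings.pkl',
#                 seq_col='aa_seq', model_type='esm2_t33_650M_UR50D',
#                 average=True, overwrite=False, print_updates=True)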
        
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False,max_length=None):
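    """
    Generate embeddings for a benchmarking dataset with every requested model.

    Writes pickles under an 'embeddings/' directory and returns a dict mapping
    each embedding path to {'model_type', 'model', 'epoch'} metadata.
    fuson_ckpts may be a dict of {model_name: [epochs]} resolved against
    ../../training/checkpoints, or the string "FusOn-pLM" for the best released
    checkpoint.
    """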
    # make directory for embeddings inside the benchmarking dataset if one doesn't already exist
    os.makedirs('embeddings',exist_ok=True)
    
    # Tag the embedding type: mean-pooled ('average') or per-residue ('2D')
    emb_type_tag = 'average' if average else '2D'
    
    all_embedding_paths = dict() # maps each embedding path to its model and epoch metadata
    
    # make the embedding files. Put them in an embedding directory
    if benchmark_fusonplm:
        os.makedirs('embeddings/fuson_plm',exist_ok=True)
            
        log_update(f"\nMaking Fuson-PLM embeddings")
        # make subdirs for all the 
        if type(fuson_ckpts)==dict:
            for model_name, epoch_list in fuson_ckpts.items():
                os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
                for epoch in epoch_list:    
                    # Assemble ckpt path and throw error if it doesn't exist
                    fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
                    if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
                    
                    # Make output directory and output embedding path
                    embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
                    embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
                    os.makedirs(embedding_output_dir,exist_ok=True)
                    
                    # Make dictionary item
                    model_type = 'fuson_plm'
                    all_embedding_paths[embedding_output_path] = {
                        'model_type': model_type,
                        'model': model_name,
                        'epoch': epoch 
                    }
            
                    # Create embeddings (or skip if they're already made)
                    log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
                    embed_dataset(input_data_path, embedding_output_path, 
                                seq_col=seq_col, model_type=model_type, 
                                fuson_ckpt_path=fuson_ckpt_path, average=average,
                                overwrite=overwrite,print_updates=True,
                                max_length=max_length)
        elif fuson_ckpts=="FusOn-pLM":
            model_name = "best"
            os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
            
            # Assemble ckpt path and throw error if it doesn't exist
            fuson_ckpt_path = "../../.." # go back to the FusOn-pLM directory to find the best ckpt
            if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
            
            # Make output directory and output embedding path
            embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
            embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
            os.makedirs(embedding_output_dir,exist_ok=True)
            
            # Make dictionary item
            model_type = 'fuson_plm'
            all_embedding_paths[embedding_output_path] = {
                'model_type': model_type,
                'model': model_name,
                'epoch': None 
            }
    
            # Create embeddings (or skip if they're already made)
            log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
            embed_dataset(input_data_path, embedding_output_path, 
                        seq_col=seq_col, model_type=model_type, 
                        fuson_ckpt_path=fuson_ckpt_path, average=average,
                        overwrite=overwrite,print_updates=True,
                        max_length=max_length)
        else:
            raise Exception(f"Error. fuson_ckpts should be a dict or str")
    
    # make the embedding files. Put them in an embedding directory
    if benchmark_esm:
        os.makedirs('embeddings/esm2_t33_650M_UR50D',exist_ok=True)
        
        # make output path
        embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'
        
        # Make dictionary item
        model_type = 'esm2_t33_650M_UR50D'
        all_embedding_paths[embedding_output_path] = {
                    'model_type': model_type,
                    'model': model_type,
                    'epoch': np.nan 
                }

        log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path, 
                    seq_col=seq_col, model_type=model_type, 
                    fuson_ckpt_path = None, average=average, 
                    overwrite=overwrite,print_updates=True,
                    max_length=max_length)
    
    if benchmark_prott5:
        os.makedirs('embeddings/prot_t5_xl_half_uniref50_enc',exist_ok=True)
        
        # make output path
        embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'
        
        # Make dictionary item
        model_type = 'prot_t5_xl_half_uniref50_enc'
        all_embedding_paths[embedding_output_path] = {
                    'model_type': model_type,
                    'model': model_type,
                    'epoch': np.nan 
                }

        log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path, 
                    seq_col=seq_col, model_type=model_type, 
                    fuson_ckpt_path = None, average=average, 
                    overwrite=overwrite,print_updates=True,
                    max_length=max_length)
        
    if benchmark_fo_puncta_ml:
        embedding_output_path = 'FOdb_physicochemical_embeddings.pkl'
        # Make dictionary item
        all_embedding_paths[embedding_output_path] = {
                    'model_type': 'fo_puncta_ml',
                    'model': 'fo_puncta_ml',
                    'epoch': np.nan 
                }
    
    return all_embedding_paths
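
# Hedged end-to-end sketch (all paths and names below are hypothetical; in the
# repo this function is driven by each benchmark's config file):
#   paths = embed_dataset_for_benchmark(
#       fuson_ckpts={'model_A': [5, 10]},      # model name -> epochs to embed
#       input_data_path='splits/test_df.csv',
#       input_fname='test_df',
#       seq_col='aa_seq',
#       benchmark_fusonplm=True,
#       benchmark_esm=True,
#   )
#   # 'paths' maps each embedding .pkl to {'model_type', 'model', 'epoch'}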