# Import the benchmarking config first so CUDA_VISIBLE_DEVICES can be set before any torch libraries are imported
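"""
Hyperparameter screen for IDR property prediction on the ALBATROSS sequence set.

For each IDR property (asph, scaled_re, scaled_rg, scaling_exp), this script trains a
ProteinMLPESM regressor on pre-computed sequence embeddings (FusOn-pLM and/or ESM-2,
per the TRAIN config), sweeping learning rate and batch size, and records train/val/test
losses, test-set R^2, and R^2 plots for the best model of each embedding type.
"""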
from fuson_plm.benchmarking.idr_prediction.config import TRAIN
import os
os.environ['CUDA_VISIBLE_DEVICES'] = TRAIN.CUDA_VISIBLE_DEVICES

import torch
import pandas as pd
import numpy as np
import pickle
from sklearn.metrics import r2_score
from sklearn.model_selection import ParameterGrid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
from fuson_plm.benchmarking.idr_prediction.model import ProteinMLPOneHot, ProteinMLPESM, LossTrackerCallback
from fuson_plm.benchmarking.idr_prediction.utils import IDRProtDataset, IDRDataModule
from fuson_plm.benchmarking.idr_prediction.plot import lengthen_model_name, plot_r2
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy

def check_env_variables():
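      """Log CUDA_VISIBLE_DEVICES and the CUDA devices visible to torch."""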
      log_update("\nChecking on environment variables...")
      log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
      log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
      for i in range(torch.cuda.device_count()):
            log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")

def grid_search_idr_predictor(embedding_path, details, output_dir, param_grid, idr_property, overwrite_saved_model=True):
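      """
      Run a grid search over training hyperparameters for one embedding file and one IDR property.

      For each (learning_rate, batch_size) combination, either train a fresh ProteinMLPESM on the saved
      splits and embeddings, or, if a matching checkpoint already exists and overwriting is disabled,
      append its previously computed R^2 results to the combined results CSV in output_dir.
      """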
      # prepare the grid search
      grid = ParameterGrid(param_grid)

      # initialize dict
      training_hyperparams = {
            "learning_rate": None,
            "batch_size": None
      }

      for params in grid:
            # Update hyperparameters
            training_hyperparams.update(params)
            log_update(f"\nHyperparams:{training_hyperparams}")
            # check if we actually need to train a model
            train_new_model, model_full_path = check_for_trained_model(details,
                                                                        training_hyperparams,
                                                                        idr_property, 
                                                                        TRAIN.PERMISSION_TO_OVERWRITE_MODELS)
            #train model
            if train_new_model:
                  model = ProteinMLPESM()
                  log_update("Initialized new model")
                  # load the splits with labels
                  train_df = pd.read_csv(f"splits/{idr_property}/train_df.csv")
                  val_df = pd.read_csv(f"splits/{idr_property}/val_df.csv")
                  test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")

                  log_update(f"\nSplit breakdown...\n\t{len(train_df)} train seqs\n\t{len(val_df)} val seqs\n\t{len(test_df)} test seqs")
                  
                  # load the embeddings 
                  with open(embedding_path, 'rb') as f:
                        combined_embeddings = pickle.load(f)
                        
                  # define the data module
                  data_module = IDRDataModule(
                        train_df=train_df, 
                        val_df=val_df, 
                        test_df=test_df, 
                        combined_embeddings=combined_embeddings,
                        idr_property=idr_property,
                        batch_size=training_hyperparams["batch_size"])
                  log_update("Initialized IDRDataModule")
                  
                  log_update("Training and evaluating...")
                  train_and_evaluate(model, data_module, model_full_path, idr_property, output_dir)
            
            # even if not training a new model, pull the old results
            else:
                  # pull the previously saved R^2 results for this checkpoint
                  r2_folder = model_full_path.split('/best-checkpoint.ckpt')[0]
                  r2_results = pd.read_csv(f"{r2_folder}/{idr_property}_r2.csv")

                  # append them to the combined results CSV in the output folder
                  test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
                  if not os.path.exists(test_r2_combined_results_csv_path):
                        r2_results.to_csv(test_r2_combined_results_csv_path, index=False)
                  else:
                        r2_results.to_csv(test_r2_combined_results_csv_path, mode='a', index=False, header=False)
                  
      return model_full_path

def find_best_hyperparams(output_dir, idr_properties):
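    """For each IDR property, keep the best (highest R^2) row per model type and save it to <property>_best_test_r2.csv."""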
    # read all the inputs 
    for idr_property in idr_properties:
          df = pd.read_csv(f"{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv")
          # path_to_model looks like trained_models/<idr_property>/<model_type>/..., so index 2 is the model type
          df['model_type'] = df['path_to_model'].apply(lambda x: x.split('/')[2])
          df = df.sort_values(by=['model_type','r2'],ascending=[True,False]).reset_index(drop=True)
          df = df.drop_duplicates(subset='model_type').reset_index(drop=True)
          df.to_csv(f"{output_dir}/{idr_property}_best_test_r2.csv", index=False)
      
def check_for_trained_model(details, training_hyperparams, idr_property, overwrite_saved_model):
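      """
      Build the checkpoint directory for this benchmark model / hyperparameter combination and decide
      whether a new model needs to be trained (False only when a checkpoint already exists and
      overwriting is not permitted). Returns (train_new_model, model_full_path).
      """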
      # unpack the details dictionary
      benchmark_model_type = details['model_type']
      benchmark_model_name = details['model']
      benchmark_model_epoch = details['epoch']
      
      # define the model save directory and create it (with any missing parents) if it doesn't exist
      model_outer_folder = f"trained_models/{idr_property}/{benchmark_model_type}"
      if not np.isnan(benchmark_model_epoch): model_outer_folder += f"/{benchmark_model_name}/epoch{benchmark_model_epoch}"
      model_full_folder = f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{training_hyperparams['batch_size']}"
      os.makedirs(model_full_folder, exist_ok=True)
            
      # see if we've trained the model before 
      model_full_path = f"{model_full_folder}/best-checkpoint.ckpt"
      train_new_model = True  # assume we're training a new model unless an existing checkpoint says otherwise
      if os.path.exists(model_full_path):
            # If the model exists and we ARE allowed to overwrite, still train
            if overwrite_saved_model:
                  log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}")
            # If the model exists and we are NOT allowed to overwrite, don't train
            else:
                  log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping")
                  train_new_model=False

      return train_new_model, model_full_path
      
def train_and_evaluate(model, data_module, model_save_path, idr_property, output_dir):
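      """
      Fit the model with PyTorch Lightning, checkpointing on the best validation loss, write per-epoch
      train/val losses and the test loss to CSV, then reload the best checkpoint and compute test-set R^2.
      """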
      early_stop_callback = EarlyStopping(
                  monitor='val_loss',
                  min_delta=0.1,
                  patience=2,
                  verbose=False,
                  mode='min'
      )
      
      loss_tracker = LossTrackerCallback()
                        
      # Save the best model based on validation loss (or another monitored metric)
      checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',        # Monitor validation loss (or another metric)
            dirpath=model_save_path.split('/best-checkpoint.ckpt')[0],   # Directory to save the model
            filename='best-checkpoint',  # File name for the best model checkpoint
            save_top_k=1,              # Only save the best model
            mode='min'                 # Mode for the monitored metric ('min' or 'max')
      )
      max_epochs = 20
      if idr_property in ['scaled_re','scaled_rg']: max_epochs=50
      trainer = Trainer(
            callbacks=[checkpoint_callback, loss_tracker],  # early_stop_callback is defined above but currently not used
            max_epochs=max_epochs,
            check_val_every_n_epoch=1,  # Ensure validation runs once per epoch
            val_check_interval=1.0      # Perform validation only after each training epoch
            )
      
      log_update("\tRunning the training loop...")
      trainer.fit(model, data_module)
      train_losses = loss_tracker.train_losses
      val_losses = loss_tracker.val_losses[1:]  # drop the extra entry from the pre-training validation sanity check
      
      # Prepare the data to write to CSV
      data = {
            'Epoch': list(range(1, len(train_losses) + 1)),
            'Train Loss': train_losses,
            'Validation Loss': val_losses
      }

      # Create a DataFrame
      df = pd.DataFrame(data)

      # Write to CSV
      train_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/train_val_losses.csv"
      df.to_csv(train_loss_individual_results_csv_path, index=False)
      # Also write to CSV in main output folder
      train_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_train_losses.csv'
      df["path_to_model"] = [model_save_path]*len(df)
      if not(os.path.exists(train_loss_combined_results_csv_path)):
            df.to_csv(train_loss_combined_results_csv_path,index=False)
      else:
            df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False)
        
      # Load the best model checkpoint before testing
      best_model_path = checkpoint_callback.best_model_path
      log_update(f"\tLoading best model from {best_model_path} for testing...")
      model = model.__class__.load_from_checkpoint(best_model_path)  # reload the best checkpoint
      
      #test model
      log_update("\tRunning the testing loop...")
      test_results = trainer.test(model, dataloaders=data_module.test_dataloader())
      test_loss = test_results[0].get('test_loss')
      df = pd.DataFrame(data={
            'Test Loss': [test_loss]
      })
      test_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/test_loss.csv"
      df.to_csv(test_loss_individual_results_csv_path, index=False)
      test_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_losses.csv'
      df['path_to_model'] = [model_save_path]*len(df)
      if not(os.path.exists(test_loss_combined_results_csv_path)):
            df.to_csv(test_loss_combined_results_csv_path,index=False)
      else:
            df.to_csv(test_loss_combined_results_csv_path,mode='a',index=False,header=False)
      
      log_update("\tCalculating R^2...")
      get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir)

def get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir):
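      """Run the trained model over the test dataloader, compute R^2, and save predictions and R^2 to CSV."""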
      #ensure the model is in evaluation mode
      model.eval()

      #store predictions and actual values
      true_values = []
      predictions = []

      #no gradient
      with torch.no_grad():
            for batch in data_module.test_dataloader():
                  inputs = batch['Protein Input']
                  labels = batch['Dimension']

                  outputs = model(inputs).squeeze(-1)  #run through model

                  #get true values and predictions
                  true_values.extend(labels.cpu().numpy())
                  predictions.extend(outputs.cpu().numpy())

      #calculate the R^2 score
      r2 = r2_score(true_values, predictions)
      log_update(f"R^2 Score: {r2}")
      
      # write the true values and predictions to a CSV
      save_folder = model_save_path.split('/best-checkpoint.ckpt')[0]
      df = pd.DataFrame(data={
            'true_values': true_values,
            'predictions': predictions
      })
      df.to_csv(f"{save_folder}/{idr_property}_test_predictions.csv",index=False)
      
      # write r2 to a csv
      df = pd.DataFrame(
            data={
                  'path_to_model': [model_save_path],
                  'r2': [r2]
            }
      )
      df.to_csv(f"{save_folder}/{idr_property}_r2.csv",index=False)
      test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
      if not(os.path.exists(test_r2_combined_results_csv_path)):
            df.to_csv(test_r2_combined_results_csv_path,index=False)
      else:
            df.to_csv(test_r2_combined_results_csv_path,mode='a',index=False,header=False)

def main():
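      """Create a timestamped results directory, embed the benchmark sequences if needed, run the
      hyperparameter screen for every IDR property, and plot R^2 for each property's best models."""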
      # make output directory for this run
      os.makedirs('results',exist_ok=True)
      output_dir = f'results/{get_local_time()}'
      os.makedirs(output_dir,exist_ok=True)
      
      with open_logfile(f'{output_dir}/idr_prediction_benchmark_log.txt'):
            # print configurations 
            TRAIN.print_config()
            
            # Verify that the environment variables are set correctly 
            check_env_variables()
            
            # make embeddings if needed
            all_embedding_paths = embed_dataset_for_benchmark(
                                    fuson_ckpts=TRAIN.FUSONPLM_CKPTS, 
                                    input_data_path=f"processed_data/all_albatross_seqs_and_properties.csv", 
                                    input_fname=f"albatross_sequences", 
                                    average=True, seq_col='Sequence',
                                    benchmark_fusonplm=TRAIN.BENCHMARK_FUSONPLM, 
                                    benchmark_esm=TRAIN.BENCHMARK_ESM, 
                                    benchmark_fo_puncta_ml=False, 
                                    overwrite=TRAIN.PERMISSION_TO_OVERWRITE_EMBEDDINGS)
            
            # loop through the different tasks
            idr_properties = ['asph','scaled_re','scaled_rg','scaling_exp']
            
            param_grid = {
                  'learning_rate': [1e-5, 3e-4, 1e-4, 3e-3, 1e-3],
                  'batch_size': [32, 64]
            }
            
            for idr_property in idr_properties:
                  log_update(f"Benchmarking property {idr_property}")
                  
                  log_update("\nTraining and evaluating models")
                  # Set hyperparameters for disorder predictor 
            
                  # loop through the embedding paths and train each one
                  for embedding_path, details in all_embedding_paths.items():
                        log_update(f"\nBenchmarking embeddings at: {embedding_path}")
                        log_update(details)
                              
                        model_full_path = grid_search_idr_predictor(embedding_path, 
                                                   details, output_dir, 
                                                   param_grid, idr_property, overwrite_saved_model=TRAIN.PERMISSION_TO_OVERWRITE_MODELS)
                  
            # find the best grid search performer
            find_best_hyperparams(output_dir, idr_properties)
            for idr_property in idr_properties:
                  # make the R^2 Plots for the BEST one
                  best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
                  model_type_to_path_dict = dict(zip(best_results['model_type'],best_results['path_to_model']))
                  for model_type, path_to_model in model_type_to_path_dict.items():
                        model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
                        test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")
                  
                        # make directories for R^2 plots
                        os.makedirs(f"{output_dir}/r2_plots/{idr_property}", exist_ok=True)
                        
                        model_type_dict = {
                              'fuson_plm': 'FusOn-pLM',
                              'esm2_t33_650M_UR50D': 'ESM-2'
                        }
                        r2_save_path = f"{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png"
                        plot_r2(model_type_dict[model_type], idr_property, test_preds, r2_save_path)             
        
if __name__ == "__main__":
    main()