# first few imports, just to set CUDA_VISIBLE_DEVICES before importing any torch libraries
from fuson_plm.benchmarking.idr_prediction.config import TRAIN
import os
os.environ['CUDA_VISIBLE_DEVICES'] = TRAIN.CUDA_VISIBLE_DEVICES

import torch
import pandas as pd
import numpy as np
import pickle
from sklearn.metrics import r2_score
from sklearn.model_selection import ParameterGrid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
from fuson_plm.benchmarking.idr_prediction.model import ProteinMLPOneHot, ProteinMLPESM, LossTrackerCallback
from fuson_plm.benchmarking.idr_prediction.utils import IDRProtDataset, IDRDataModule
from fuson_plm.benchmarking.idr_prediction.plot import lengthen_model_name, plot_r2
from fuson_plm.benchmarking.idr_prediction.config import TRAIN
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy


def check_env_variables():
    log_update("\nChecking on environment variables...")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")


def grid_search_idr_predictor(embedding_path, details, output_dir, param_grid, idr_property, overwrite_saved_model=True):
    # prepare the grid search
    grid = ParameterGrid(param_grid)

    # initialize dict
    training_hyperparams = {
        "learning_rate": None,
        "batch_size": None
    }

    for params in grid:
        # Update hyperparameters
        training_hyperparams.update(params)
        log_update(f"\nHyperparams:{training_hyperparams}")

        # check if we actually need to train a model
        train_new_model, model_full_path = check_for_trained_model(details, training_hyperparams, idr_property, TRAIN.PERMISSION_TO_OVERWRITE_MODELS)

        # train model
        if train_new_model:
            model = ProteinMLPESM()
            log_update("Initialized new model")

            # load the splits with labels
            train_df = pd.read_csv(f"splits/{idr_property}/train_df.csv")
            val_df = pd.read_csv(f"splits/{idr_property}/val_df.csv")
            test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")
            log_update(f"\nSplit breakdown...\n\t{len(train_df)} train seqs\n\t{len(val_df)} val seqs\n\t{len(test_df)} test seqs")

            # load the embeddings
            with open(embedding_path, 'rb') as f:
                combined_embeddings = pickle.load(f)

            # define the data module
            data_module = IDRDataModule(
                train_df=train_df,
                val_df=val_df,
                test_df=test_df,
                combined_embeddings=combined_embeddings,
                idr_property=idr_property,
                batch_size=training_hyperparams["batch_size"])
            log_update("Initialized IDRDataModule")

            log_update("Training and evaluating...")
            train_and_evaluate(model, data_module, model_full_path, idr_property, output_dir)
        # even if not training a new model, pull the old results
        else:
            # write the results to the results folder anyway
            r2_folder = model_full_path.split('/best-checkpoint.ckpt')[0]
            r2_results = pd.read_csv(f"{r2_folder}/{idr_property}_r2.csv")
            # write to the results folder
            test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
            if not(os.path.exists(test_r2_combined_results_csv_path)):
                r2_results.to_csv(test_r2_combined_results_csv_path, index=False)
            else:
                r2_results.to_csv(test_r2_combined_results_csv_path, mode='a', index=False, header=False)

    return model_full_path


def find_best_hyperparams(output_dir, idr_properties):
    # read all the inputs
    for idr_property in idr_properties:
        df = pd.read_csv(f"{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv")
        # string starts like trained_models/asph/fuson_plm/
        df['model_type'] = df['path_to_model'].apply(lambda x: x.split('/')[2])
        df = df.sort_values(by=['model_type', 'r2'], ascending=[True, False]).reset_index(drop=True)
        df = df.drop_duplicates(subset='model_type').reset_index(drop=True)
        df.to_csv(f"{output_dir}/{idr_property}_best_test_r2.csv", index=False)


def check_for_trained_model(details, training_hyperparams, idr_property, overwrite_saved_model):
    # unpack the details dictionary
    benchmark_model_type = details['model_type']
    benchmark_model_name = details['model']
    benchmark_model_epoch = details['epoch']

    # define model save directories and make them if they don't exist
    os.makedirs(f"trained_models/{idr_property}", exist_ok=True)
    model_outer_folder = f"trained_models/{idr_property}/{benchmark_model_type}"
    if not(np.isnan(benchmark_model_epoch)):
        model_outer_folder += f"/{benchmark_model_name}/epoch{benchmark_model_epoch}"
    model_full_folder = f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{training_hyperparams['batch_size']}"
    l_model_full_folder = model_full_folder.split("/")
    for i in range(0, len(l_model_full_folder)):
        newdir = "/".join(l_model_full_folder[:i+1])
        os.makedirs(newdir, exist_ok=True)

    # see if we've trained the model before
    model_full_path = f"{model_full_folder}/best-checkpoint.ckpt"
    train_new_model = True  # initially, we believe we're training a new model. Let's make sure we want to.
    if os.path.exists(model_full_path):
        # If the model exists and we ARE allowed to overwrite, still train
        if overwrite_saved_model:
            log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}")
        # If the model exists and we are NOT allowed to overwrite, don't train
        else:
            log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping")
Skipping") train_new_model=False return train_new_model, model_full_path def train_and_evaluate(model, data_module, model_save_path, idr_property, output_dir): early_stop_callback = EarlyStopping( monitor='val_loss', min_delta=0.1, patience=2, verbose=False, mode='min' ) loss_tracker = LossTrackerCallback() # Save the best model based on validation loss (or another monitored metric) checkpoint_callback = ModelCheckpoint( monitor='val_loss', # Monitor validation loss (or another metric) dirpath=model_save_path.split('/best-checkpoint.ckpt')[0], # Directory to save the model filename='best-checkpoint', # File name for the best model checkpoint save_top_k=1, # Only save the best model mode='min' # Mode for the monitored metric ('min' or 'max') ) max_epochs = 20 if idr_property in ['scaled_re','scaled_rg']: max_epochs=50 trainer = Trainer( callbacks=[checkpoint_callback, loss_tracker], #, early_stop_callback, max_epochs=max_epochs, check_val_every_n_epoch=1, # Ensure validation runs once per epoch val_check_interval=1.0 # Perform validation only after each training epoch ) log_update("\tRunning the training loop...") trainer.fit(model, data_module) train_losses = loss_tracker.train_losses val_losses = loss_tracker.val_losses[1::] # there's an extra at the beginning # Prepare the data to write to CSV data = { 'Epoch': list(range(1, len(train_losses) + 1)), 'Train Loss': train_losses, 'Validation Loss': val_losses } # Create a DataFrame df = pd.DataFrame(data) # Write to CSV train_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/train_val_losses.csv" df.to_csv(train_loss_individual_results_csv_path, index=False) # Also write to CSV in main output folder train_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_train_losses.csv' df["path_to_model"] = [model_save_path]*len(df) if not(os.path.exists(train_loss_combined_results_csv_path)): df.to_csv(train_loss_combined_results_csv_path,index=False) else: df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False) # Load the best model checkpoint before testing best_model_path = checkpoint_callback.best_model_path log_update(f"\tLoading best model from {best_model_path} for testing...") #model = model.load_from_checkpoint(best_model_path) # Reload the best model model = model.__class__.load_from_checkpoint(best_model_path) #test model log_update("\tRunning the testing loop...") test_results = trainer.test(model, dataloaders=data_module.test_dataloader()) test_loss = test_results[0]['test_loss'] if 'test_loss' in test_results[0] else None df = pd.DataFrame(data={ 'Test Loss': [test_loss] }) test_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/test_loss.csv" df.to_csv(test_loss_individual_results_csv_path, index=False) test_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_losses.csv' df['path_to_model'] = [model_save_path]*len(df) if not(os.path.exists(test_loss_combined_results_csv_path)): df.to_csv(test_loss_combined_results_csv_path,index=False) else: df.to_csv(test_loss_combined_results_csv_path,mode='a',index=False,header=False) log_update("\tCalculating R^2...") get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir) def get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir): #ensure the model is in evaluation mode model.eval() #store predictions and actual values true_values = [] predictions = [] #no gradient with 
        for batch in data_module.test_dataloader():
            inputs = batch['Protein Input']
            labels = batch['Dimension']
            outputs = model(inputs).squeeze(-1)  # run through model

            # get true values and predictions
            true_values.extend(labels.cpu().numpy())
            predictions.extend(outputs.cpu().numpy())

    # calculate the R^2 score
    r2 = r2_score(true_values, predictions)
    log_update(f"R^2 Score: {r2}")

    # write the true values and predictions to a CSV
    save_folder = model_save_path.split('/best-checkpoint.ckpt')[0]
    df = pd.DataFrame(data={
        'true_values': true_values,
        'predictions': predictions
    })
    df.to_csv(f"{save_folder}/{idr_property}_test_predictions.csv", index=False)

    # write r2 to a csv
    df = pd.DataFrame(data={
        'path_to_model': [model_save_path],
        'r2': [r2]
    })
    df.to_csv(f"{save_folder}/{idr_property}_r2.csv", index=False)

    test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
    if not(os.path.exists(test_r2_combined_results_csv_path)):
        df.to_csv(test_r2_combined_results_csv_path, index=False)
    else:
        df.to_csv(test_r2_combined_results_csv_path, mode='a', index=False, header=False)


def main():
    # make output directory for this run
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/idr_prediction_benchmark_log.txt'):
        # print configurations
        TRAIN.print_config()

        # Verify that the environment variables are set correctly
        check_env_variables()

        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=TRAIN.FUSONPLM_CKPTS,
            input_data_path="processed_data/all_albatross_seqs_and_properties.csv",
            input_fname="albatross_sequences",
            average=True,
            seq_col='Sequence',
            benchmark_fusonplm=TRAIN.BENCHMARK_FUSONPLM,
            benchmark_esm=TRAIN.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=False,
            overwrite=TRAIN.PERMISSION_TO_OVERWRITE_EMBEDDINGS)

        # loop through the different tasks
        idr_properties = ['asph', 'scaled_re', 'scaled_rg', 'scaling_exp']
        param_grid = {
            'learning_rate': [1e-5, 3e-4, 1e-4, 3e-3, 1e-3],
            'batch_size': [32, 64]
        }

        for idr_property in idr_properties:
            log_update(f"Benchmarking property {idr_property}")
            log_update("\nTraining and evaluating models")

            # loop through the embedding paths and train a predictor on each one
            for embedding_path, details in all_embedding_paths.items():
                log_update(f"\nBenchmarking embeddings at: {embedding_path}")
                log_update(details)
                model_full_path = grid_search_idr_predictor(embedding_path, details, output_dir, param_grid, idr_property, overwrite_saved_model=TRAIN.PERMISSION_TO_OVERWRITE_MODELS)

        # find the best grid search performer
        find_best_hyperparams(output_dir, idr_properties)

        for idr_property in idr_properties:
            # make the R^2 plots for the BEST model of each type
            best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
            model_type_to_path_dict = dict(zip(best_results['model_type'], best_results['path_to_model']))
            for model_type, path_to_model in model_type_to_path_dict.items():
                model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
                test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")

                # make paths for R^2 plots
                if not os.path.exists(f"{output_dir}/r2_plots"):
                    os.makedirs(f"{output_dir}/r2_plots")
                os.makedirs(f"{output_dir}/r2_plots/{idr_property}", exist_ok=True)

                model_type_dict = {
                    'fuson_plm': 'FusOn-pLM',
                    'esm2_t33_650M_UR50D': 'ESM-2'
                }
                r2_save_path = f"{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png"
                plot_r2(model_type_dict[model_type], idr_property, test_preds, r2_save_path)


if __name__ == "__main__":
    main()