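"""Hyperparameter grid search for the IDR property prediction benchmark.

Trains MLP regressors on precomputed protein embeddings (FusOn-pLM and ESM-2) to predict
ALBATROSS IDR properties (asph, scaled_re, scaled_rg, scaling_exp), logging per-epoch losses,
test losses, R^2 scores, and R^2 plots for the best hyperparameter setting of each model type.
"""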
# first few imports, just to set CUDA_VISIBLE_DEVICES before importing any torch libraries
from fuson_plm.benchmarking.idr_prediction.config import TRAIN
import os
os.environ['CUDA_VISIBLE_DEVICES'] = TRAIN.CUDA_VISIBLE_DEVICES
import torch
import pandas as pd
import numpy as np
import pickle
from sklearn.metrics import r2_score
from sklearn.model_selection import ParameterGrid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
from fuson_plm.benchmarking.idr_prediction.model import ProteinMLPOneHot, ProteinMLPESM, LossTrackerCallback
from fuson_plm.benchmarking.idr_prediction.utils import IDRProtDataset, IDRDataModule
from fuson_plm.benchmarking.idr_prediction.plot import lengthen_model_name, plot_r2
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy
def check_env_variables():
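    """Log CUDA_VISIBLE_DEVICES and the GPU devices PyTorch can see, so the hardware setup is recorded in the run log."""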
log_update("\nChecking on environment variables...")
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
def grid_search_idr_predictor(embedding_path, details, output_dir, param_grid, idr_property, overwrite_saved_model=True):
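    """Grid-search learning rate and batch size for one embedding file and one IDR property.

    For each hyperparameter combination, train a new ProteinMLPESM on the precomputed embeddings,
    or, if a checkpoint already exists and overwriting is disabled, copy its saved R^2 results
    into the combined hyperparameter-screen CSV. Returns the checkpoint path from the final
    grid iteration.
    """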
# prepare the grid search
grid = ParameterGrid(param_grid)
# initialize dict
training_hyperparams = {
"learning_rate": None,
"batch_size": None
}
for params in grid:
# Update hyperparameters
training_hyperparams.update(params)
log_update(f"\nHyperparams:{training_hyperparams}")
# check if we actually need to train a model
        train_new_model, model_full_path = check_for_trained_model(details,
                                                                    training_hyperparams,
                                                                    idr_property,
                                                                    overwrite_saved_model)
#train model
if train_new_model:
model = ProteinMLPESM()
log_update("Initialized new model")
# load the splits with labels
train_df = pd.read_csv(f"splits/{idr_property}/train_df.csv")
val_df = pd.read_csv(f"splits/{idr_property}/val_df.csv")
test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")
log_update(f"\nSplit breakdown...\n\t{len(train_df)} train seqs\n\t{len(val_df)} val seqs\n\t{len(test_df)} test seqs")
# load the embeddings
with open(embedding_path, 'rb') as f:
combined_embeddings = pickle.load(f)
# define the data module
data_module = IDRDataModule(
train_df=train_df,
val_df=val_df,
test_df=test_df,
combined_embeddings=combined_embeddings,
idr_property=idr_property,
batch_size=training_hyperparams["batch_size"])
log_update("Initialized IDRDataModule")
log_update("Training and evaluating...")
train_and_evaluate(model, data_module, model_full_path, idr_property, output_dir)
# even if not training a new model, pull the old results
else:
# write the results to the results folder anyway
r2_folder = model_full_path.split('/best-checkpoint.ckpt')[0]
r2_results = pd.read_csv(f"{r2_folder}/{idr_property}_r2.csv")
# write to the results folder
test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
if not(os.path.exists(test_r2_combined_results_csv_path)):
r2_results.to_csv(test_r2_combined_results_csv_path,index=False)
else:
r2_results.to_csv(test_r2_combined_results_csv_path,mode='a',index=False,header=False)
return model_full_path
def find_best_hyperparams(output_dir, idr_properties):
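    """For each IDR property, keep the best-R^2 row per model type from the hyperparameter screen and write it to {output_dir}/{idr_property}_best_test_r2.csv."""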
# read all the inputs
for idr_property in idr_properties:
df = pd.read_csv(f"{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv")
        # path_to_model looks like trained_models/{idr_property}/{model_type}/..., so index 2 is the model type
        df['model_type'] = df['path_to_model'].apply(lambda x: x.split('/')[2])
df = df.sort_values(by=['model_type','r2'],ascending=[True,False]).reset_index(drop=True)
df = df.drop_duplicates(subset='model_type').reset_index(drop=True)
df.to_csv(f"{output_dir}/{idr_property}_best_test_r2.csv", index=False)
def check_for_trained_model(details, training_hyperparams, idr_property, overwrite_saved_model):
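    """Build the checkpoint path for this model and hyperparameter setting and decide whether to train.

    Returns (train_new_model, model_full_path); training is skipped only when the checkpoint
    already exists and overwrite_saved_model is False.
    """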
    # unpack the details dictionary
benchmark_model_type = details['model_type']
benchmark_model_name = details['model']
benchmark_model_epoch = details['epoch']
# define model save directories and make if they don't exist
os.makedirs(f"trained_models/{idr_property}",exist_ok=True)
model_outer_folder = f"trained_models/{idr_property}/{benchmark_model_type}"
if not(np.isnan(benchmark_model_epoch)): model_outer_folder+=f"/{benchmark_model_name}/epoch{benchmark_model_epoch}"
model_full_folder=f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{training_hyperparams['batch_size']}"
l_model_full_folder = model_full_folder.split("/")
for i in range(0,len(l_model_full_folder)):
newdir="/".join(l_model_full_folder[:i+1])
os.makedirs(newdir, exist_ok=True)
# see if we've trained the model before
model_full_path = f"{model_full_folder}/best-checkpoint.ckpt"
train_new_model=True #initially, we believe we're training a new model. Let's make sure we want to.
if os.path.exists(model_full_path):
# If the model exists and we ARE allowed to overwrite, still train
if overwrite_saved_model:
log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}")
# If the model exists and we are NOT allowed to overwrite, don't train
else:
log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping")
train_new_model=False
return train_new_model, model_full_path
def train_and_evaluate(model, data_module, model_save_path, idr_property, output_dir):
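    """Train the model with PyTorch Lightning, save per-epoch train/val losses, then test the best checkpoint and record its test loss and R^2."""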
early_stop_callback = EarlyStopping(
monitor='val_loss',
min_delta=0.1,
patience=2,
verbose=False,
mode='min'
)
loss_tracker = LossTrackerCallback()
# Save the best model based on validation loss (or another monitored metric)
checkpoint_callback = ModelCheckpoint(
monitor='val_loss', # Monitor validation loss (or another metric)
dirpath=model_save_path.split('/best-checkpoint.ckpt')[0], # Directory to save the model
filename='best-checkpoint', # File name for the best model checkpoint
save_top_k=1, # Only save the best model
mode='min' # Mode for the monitored metric ('min' or 'max')
)
max_epochs = 20
if idr_property in ['scaled_re','scaled_rg']: max_epochs=50
trainer = Trainer(
callbacks=[checkpoint_callback, loss_tracker], #, early_stop_callback,
max_epochs=max_epochs,
check_val_every_n_epoch=1, # Ensure validation runs once per epoch
val_check_interval=1.0 # Perform validation only after each training epoch
)
log_update("\tRunning the training loop...")
trainer.fit(model, data_module)
train_losses = loss_tracker.train_losses
    val_losses = loss_tracker.val_losses[1:]  # drop the extra entry logged by the pre-training validation sanity check
# Prepare the data to write to CSV
data = {
'Epoch': list(range(1, len(train_losses) + 1)),
'Train Loss': train_losses,
'Validation Loss': val_losses
}
# Create a DataFrame
df = pd.DataFrame(data)
# Write to CSV
train_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/train_val_losses.csv"
df.to_csv(train_loss_individual_results_csv_path, index=False)
# Also write to CSV in main output folder
train_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_train_losses.csv'
df["path_to_model"] = [model_save_path]*len(df)
if not(os.path.exists(train_loss_combined_results_csv_path)):
df.to_csv(train_loss_combined_results_csv_path,index=False)
else:
df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False)
# Load the best model checkpoint before testing
best_model_path = checkpoint_callback.best_model_path
log_update(f"\tLoading best model from {best_model_path} for testing...")
    model = model.__class__.load_from_checkpoint(best_model_path)  # reload the best checkpoint into a fresh model instance
#test model
log_update("\tRunning the testing loop...")
test_results = trainer.test(model, dataloaders=data_module.test_dataloader())
test_loss = test_results[0]['test_loss'] if 'test_loss' in test_results[0] else None
df = pd.DataFrame(data={
'Test Loss': [test_loss]
})
test_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/test_loss.csv"
df.to_csv(test_loss_individual_results_csv_path, index=False)
test_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_losses.csv'
df['path_to_model'] = [model_save_path]*len(df)
if not(os.path.exists(test_loss_combined_results_csv_path)):
df.to_csv(test_loss_combined_results_csv_path,index=False)
else:
df.to_csv(test_loss_combined_results_csv_path,mode='a',index=False,header=False)
log_update("\tCalculating R^2...")
get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir)
def get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir):
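    """Run the model over the test dataloader, compute R^2 between predictions and true values, and write the predictions and R^2 to per-model and combined CSVs."""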
#ensure the model is in evaluation mode
model.eval()
#store predictions and actual values
true_values = []
predictions = []
#no gradient
with torch.no_grad():
for batch in data_module.test_dataloader():
inputs = batch['Protein Input']
labels = batch['Dimension']
outputs = model(inputs).squeeze(-1) #run through model
#get true values and predictions
true_values.extend(labels.cpu().numpy())
predictions.extend(outputs.cpu().numpy())
#calculate the R^2 score
r2 = r2_score(true_values, predictions)
log_update(f"R^2 Score: {r2}")
# write the true values and predictions to a CSV
save_folder = model_save_path.split('/best-checkpoint.ckpt')[0]
df = pd.DataFrame(data={
'true_values': true_values,
'predictions': predictions
})
df.to_csv(f"{save_folder}/{idr_property}_test_predictions.csv",index=False)
# write r2 to a csv
df = pd.DataFrame(
data={
'path_to_model': [model_save_path],
'r2': [r2]
}
)
df.to_csv(f"{save_folder}/{idr_property}_r2.csv",index=False)
test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv'
if not(os.path.exists(test_r2_combined_results_csv_path)):
df.to_csv(test_r2_combined_results_csv_path,index=False)
else:
df.to_csv(test_r2_combined_results_csv_path,mode='a',index=False,header=False)
def main():
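    """Entry point: embed the ALBATROSS sequences, grid-search an MLP for each IDR property and embedding type, then plot R^2 for the best model of each type."""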
# make output directory for this run
os.makedirs('results',exist_ok=True)
output_dir = f'results/{get_local_time()}'
os.makedirs(output_dir,exist_ok=True)
with open_logfile(f'{output_dir}/idr_prediction_benchmark_log.txt'):
# print configurations
TRAIN.print_config()
# Verify that the environment variables are set correctly
check_env_variables()
# make embeddings if needed
all_embedding_paths = embed_dataset_for_benchmark(
fuson_ckpts=TRAIN.FUSONPLM_CKPTS,
input_data_path=f"processed_data/all_albatross_seqs_and_properties.csv",
input_fname=f"albatross_sequences",
average=True, seq_col='Sequence',
benchmark_fusonplm=TRAIN.BENCHMARK_FUSONPLM,
benchmark_esm=TRAIN.BENCHMARK_ESM,
benchmark_fo_puncta_ml=False,
overwrite=TRAIN.PERMISSION_TO_OVERWRITE_EMBEDDINGS)
# loop through the different tasks
idr_properties = ['asph','scaled_re','scaled_rg','scaling_exp']
param_grid = {
'learning_rate': [1e-5, 3e-4, 1e-4, 3e-3, 1e-3],
'batch_size': [32, 64]
}
for idr_property in idr_properties:
log_update(f"Benchmarking property {idr_property}")
log_update("\nTraining and evaluating models")
            # loop through the embedding paths and grid-search a predictor on each one
for embedding_path, details in all_embedding_paths.items():
log_update(f"\nBenchmarking embeddings at: {embedding_path}")
log_update(details)
model_full_path = grid_search_idr_predictor(embedding_path,
details, output_dir,
param_grid, idr_property, overwrite_saved_model=TRAIN.PERMISSION_TO_OVERWRITE_MODELS)
# find the best grid search performer
find_best_hyperparams(output_dir, idr_properties)
for idr_property in idr_properties:
# make the R^2 Plots for the BEST one
best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
model_type_to_path_dict = dict(zip(best_results['model_type'],best_results['path_to_model']))
for model_type, path_to_model in model_type_to_path_dict.items():
model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")
# make paths for R^2 plots
if not os.path.exists(f"{output_dir}/r2_plots"):
os.makedirs(f"{output_dir}/r2_plots")
os.makedirs(f"{output_dir}/r2_plots/{idr_property}", exist_ok=True)
model_type_dict = {
'fuson_plm': 'FusOn-pLM',
'esm2_t33_650M_UR50D': 'ESM-2'
}
r2_save_path = f"{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png"
plot_r2(model_type_dict[model_type], idr_property, test_preds, r2_save_path)
if __name__ == "__main__":
main()