|
|
|
"""Benchmark IDR property prediction: grid-search MLP regressors trained on protein embeddings of ALBATROSS sequences."""
# Import the training config first so CUDA_VISIBLE_DEVICES is set before torch is imported
from fuson_plm.benchmarking.idr_prediction.config import TRAIN
|
import os |
|
os.environ['CUDA_VISIBLE_DEVICES'] = TRAIN.CUDA_VISIBLE_DEVICES |
|
|
|
import torch |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
from sklearn.metrics import r2_score |
|
from sklearn.model_selection import ParameterGrid |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning import Trainer |
|
|
|
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
|
from fuson_plm.benchmarking.idr_prediction.model import ProteinMLPOneHot, ProteinMLPESM, LossTrackerCallback |
|
from fuson_plm.benchmarking.idr_prediction.utils import IDRProtDataset, IDRDataModule |
|
from fuson_plm.benchmarking.idr_prediction.plot import lengthen_model_name, plot_r2 |
|
|
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy |
|
|
|
def check_env_variables(): |
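    """Log CUDA_VISIBLE_DEVICES and the GPU devices visible to torch."""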
|
log_update("\nChecking on environment variables...") |
|
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}") |
|
log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}") |
|
for i in range(torch.cuda.device_count()): |
|
log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}") |
|
|
|
def grid_search_idr_predictor(embedding_path, details, output_dir, param_grid, idr_property, overwrite_saved_model=True): |
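    """Train one regressor per hyperparameter combination in param_grid on the given embeddings.

    Combinations with an existing checkpoint are reused (their saved R^2 is re-appended to the
    combined results) unless overwriting is permitted. Returns the path of the last checkpoint considered.
    """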
|
|
|
grid = ParameterGrid(param_grid) |
|
|
|
|
|
training_hyperparams = { |
|
"learning_rate": None, |
|
"batch_size": None |
|
} |
|
|
|
for params in grid: |
|
|
|
training_hyperparams.update(params) |
|
log_update(f"\nHyperparams:{training_hyperparams}") |
|
|
|
        # Decide whether to train: skip if a checkpoint for these hyperparams exists and overwriting is disallowed
        train_new_model, model_full_path = check_for_trained_model(
            details, training_hyperparams, idr_property, overwrite_saved_model)
|
|
|
if train_new_model: |
|
model = ProteinMLPESM() |
|
log_update("Initialized new model") |
|
|
|
train_df = pd.read_csv(f"splits/{idr_property}/train_df.csv") |
|
val_df = pd.read_csv(f"splits/{idr_property}/val_df.csv") |
|
test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv") |
|
|
|
log_update(f"\nSplit breakdown...\n\t{len(train_df)} train seqs\n\t{len(val_df)} val seqs\n\t{len(test_df)} test seqs") |
|
|
|
|
|
with open(embedding_path, 'rb') as f: |
|
combined_embeddings = pickle.load(f) |
|
|
|
|
|
data_module = IDRDataModule( |
|
train_df=train_df, |
|
val_df=val_df, |
|
test_df=test_df, |
|
combined_embeddings=combined_embeddings, |
|
idr_property=idr_property, |
|
batch_size=training_hyperparams["batch_size"]) |
|
log_update("Initialized IDRDataModule") |
|
|
|
log_update("Training and evaluating...") |
|
train_and_evaluate(model, data_module, model_full_path, idr_property, output_dir) |
|
|
|
|
|
else: |
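            # A checkpoint already exists and overwriting is off: reuse its saved R^2 instead of retraining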
|
|
|
r2_folder = model_full_path.split('/best-checkpoint.ckpt')[0] |
|
r2_results = pd.read_csv(f"{r2_folder}/{idr_property}_r2.csv") |
|
|
|
|
|
test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv' |
|
if not(os.path.exists(test_r2_combined_results_csv_path)): |
|
r2_results.to_csv(test_r2_combined_results_csv_path,index=False) |
|
else: |
|
r2_results.to_csv(test_r2_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
return model_full_path |
|
|
|
def find_best_hyperparams(output_dir, idr_properties): |
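    """For each property, keep the hyperparameter combination with the highest test R^2 per model type and write it to <property>_best_test_r2.csv."""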
|
|
|
for idr_property in idr_properties: |
|
df = pd.read_csv(f"{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv") |
|
|
|
        # model_type is the third path component: trained_models/<idr_property>/<model_type>/...
        df['model_type'] = df['path_to_model'].apply(lambda x: x.split('/')[2])
|
df = df.sort_values(by=['model_type','r2'],ascending=[True,False]).reset_index(drop=True) |
|
df = df.drop_duplicates(subset='model_type').reset_index(drop=True) |
|
df.to_csv(f"{output_dir}/{idr_property}_best_test_r2.csv", index=False) |
|
|
|
def check_for_trained_model(details, training_hyperparams, idr_property, overwrite_saved_model): |
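    """Build the checkpoint path for this model/hyperparameter combination and report whether training is needed.

    Returns (train_new_model, model_full_path); train_new_model is False when a checkpoint
    already exists and overwriting is not permitted.
    """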
|
|
|
benchmark_model_type = details['model_type'] |
|
benchmark_model_name = details['model'] |
|
benchmark_model_epoch = details['epoch'] |
|
|
|
|
|
os.makedirs(f"trained_models/{idr_property}",exist_ok=True) |
|
model_outer_folder = f"trained_models/{idr_property}/{benchmark_model_type}" |
|
if not(np.isnan(benchmark_model_epoch)): model_outer_folder+=f"/{benchmark_model_name}/epoch{benchmark_model_epoch}" |
|
model_full_folder=f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{training_hyperparams['batch_size']}" |
|
l_model_full_folder = model_full_folder.split("/") |
|
for i in range(0,len(l_model_full_folder)): |
|
newdir="/".join(l_model_full_folder[:i+1]) |
|
os.makedirs(newdir, exist_ok=True) |
|
|
|
|
|
model_full_path = f"{model_full_folder}/best-checkpoint.ckpt" |
|
train_new_model=True |
|
if os.path.exists(model_full_path): |
|
|
|
if overwrite_saved_model: |
|
log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}") |
|
|
|
else: |
|
log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping") |
|
train_new_model=False |
|
|
|
return train_new_model, model_full_path |
|
|
|
def train_and_evaluate(model, data_module, model_save_path, idr_property, output_dir): |
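    """Train the MLP with PyTorch Lightning, save per-epoch losses, evaluate the best checkpoint on the test set, and record R^2."""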
|
early_stop_callback = EarlyStopping( |
|
monitor='val_loss', |
|
min_delta=0.1, |
|
patience=2, |
|
verbose=False, |
|
mode='min' |
|
) |
|
|
|
loss_tracker = LossTrackerCallback() |
|
|
|
|
|
checkpoint_callback = ModelCheckpoint( |
|
monitor='val_loss', |
|
dirpath=model_save_path.split('/best-checkpoint.ckpt')[0], |
|
filename='best-checkpoint', |
|
save_top_k=1, |
|
mode='min' |
|
) |
|
max_epochs = 20 |
|
if idr_property in ['scaled_re','scaled_rg']: max_epochs=50 |
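    # Note: early_stop_callback is configured above but not passed to the Trainer below; add it to callbacks to enable early stopping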
|
trainer = Trainer( |
|
callbacks=[checkpoint_callback, loss_tracker], |
|
max_epochs=max_epochs, |
|
check_val_every_n_epoch=1, |
|
val_check_interval=1.0 |
|
) |
|
|
|
log_update("\tRunning the training loop...") |
|
trainer.fit(model, data_module) |
|
    train_losses = loss_tracker.train_losses
    # Drop the first recorded validation loss (from the pre-training sanity-check pass) so train/val lists align per epoch
    val_losses = loss_tracker.val_losses[1:]
|
|
|
|
|
data = { |
|
'Epoch': list(range(1, len(train_losses) + 1)), |
|
'Train Loss': train_losses, |
|
'Validation Loss': val_losses |
|
} |
|
|
|
|
|
df = pd.DataFrame(data) |
|
|
|
|
|
train_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/train_val_losses.csv" |
|
df.to_csv(train_loss_individual_results_csv_path, index=False) |
|
|
|
train_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_train_losses.csv' |
|
df["path_to_model"] = [model_save_path]*len(df) |
|
if not(os.path.exists(train_loss_combined_results_csv_path)): |
|
df.to_csv(train_loss_combined_results_csv_path,index=False) |
|
else: |
|
df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
best_model_path = checkpoint_callback.best_model_path |
|
log_update(f"\tLoading best model from {best_model_path} for testing...") |
|
|
|
model = model.__class__.load_from_checkpoint(best_model_path) |
|
|
|
|
|
log_update("\tRunning the testing loop...") |
|
test_results = trainer.test(model, dataloaders=data_module.test_dataloader()) |
|
    test_loss = test_results[0].get('test_loss')
|
df = pd.DataFrame(data={ |
|
'Test Loss': [test_loss] |
|
}) |
|
test_loss_individual_results_csv_path = f"{model_save_path.split('/best-checkpoint.ckpt')[0]}/test_loss.csv" |
|
df.to_csv(test_loss_individual_results_csv_path, index=False) |
|
test_loss_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_losses.csv' |
|
df['path_to_model'] = [model_save_path]*len(df) |
|
if not(os.path.exists(test_loss_combined_results_csv_path)): |
|
df.to_csv(test_loss_combined_results_csv_path,index=False) |
|
else: |
|
df.to_csv(test_loss_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
log_update("\tCalculating R^2...") |
|
get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir) |
|
|
|
def get_test_preds_values(model, data_module, model_save_path, idr_property, output_dir): |
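    """Collect test-set predictions and true values, compute R^2, and append the results to per-model and combined CSVs."""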
|
|
|
    # Run inference on the test set and compare predictions with the true IDR property values
    model.eval()
    device = next(model.parameters()).device
|
|
|
|
|
true_values = [] |
|
predictions = [] |
|
|
|
|
|
with torch.no_grad(): |
|
for batch in data_module.test_dataloader(): |
|
            # Keep inputs on the same device as the model (CPU by default after loading from a checkpoint)
            inputs = batch['Protein Input'].to(device)
            labels = batch['Dimension']
|
|
|
outputs = model(inputs).squeeze(-1) |
|
|
|
|
|
true_values.extend(labels.cpu().numpy()) |
|
predictions.extend(outputs.cpu().numpy()) |
|
|
|
|
|
r2 = r2_score(true_values, predictions) |
|
log_update(f"R^2 Score: {r2}") |
|
|
|
|
|
save_folder = model_save_path.split('/best-checkpoint.ckpt')[0] |
|
df = pd.DataFrame(data={ |
|
'true_values': true_values, |
|
'predictions': predictions |
|
}) |
|
df.to_csv(f"{save_folder}/{idr_property}_test_predictions.csv",index=False) |
|
|
|
|
|
df = pd.DataFrame( |
|
data={ |
|
'path_to_model': [model_save_path], |
|
'r2': [r2] |
|
} |
|
) |
|
df.to_csv(f"{save_folder}/{idr_property}_r2.csv",index=False) |
|
test_r2_combined_results_csv_path = f'{output_dir}/{idr_property}_hyperparam_screen_test_r2.csv' |
|
if not(os.path.exists(test_r2_combined_results_csv_path)): |
|
df.to_csv(test_r2_combined_results_csv_path,index=False) |
|
else: |
|
df.to_csv(test_r2_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
def main(): |
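    """Embed the ALBATROSS dataset, grid-search MLP hyperparameters for each IDR property, and plot the best models' R^2."""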
|
|
|
os.makedirs('results',exist_ok=True) |
|
output_dir = f'results/{get_local_time()}' |
|
os.makedirs(output_dir,exist_ok=True) |
|
|
|
with open_logfile(f'{output_dir}/idr_prediction_benchmark_log.txt'): |
|
|
|
TRAIN.print_config() |
|
|
|
|
|
check_env_variables() |
|
|
|
|
|
        # Embed the ALBATROSS sequences with every model being benchmarked (FusOn-pLM checkpoints and/or ESM-2)
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=TRAIN.FUSONPLM_CKPTS,
            input_data_path="processed_data/all_albatross_seqs_and_properties.csv",
            input_fname="albatross_sequences",
            average=True, seq_col='Sequence',
            benchmark_fusonplm=TRAIN.BENCHMARK_FUSONPLM,
            benchmark_esm=TRAIN.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=False,
            overwrite=TRAIN.PERMISSION_TO_OVERWRITE_EMBEDDINGS)
|
|
|
|
|
        # ALBATROSS IDR properties: asphericity, scaled end-to-end distance, scaled radius of gyration, and scaling exponent
        idr_properties = ['asph','scaled_re','scaled_rg','scaling_exp']
|
|
|
param_grid = { |
|
'learning_rate': [1e-5, 3e-4, 1e-4, 3e-3, 1e-3], |
|
'batch_size': [32, 64] |
|
} |
|
|
|
for idr_property in idr_properties: |
|
log_update(f"Benchmarking property {idr_property}") |
|
|
|
log_update("\nTraining and evaluating models") |
|
|
|
|
|
|
|
for embedding_path, details in all_embedding_paths.items(): |
|
log_update(f"\nBenchmarking embeddings at: {embedding_path}") |
|
log_update(details) |
|
|
|
model_full_path = grid_search_idr_predictor(embedding_path, |
|
details, output_dir, |
|
param_grid, idr_property, overwrite_saved_model=TRAIN.PERMISSION_TO_OVERWRITE_MODELS) |
|
|
|
|
|
find_best_hyperparams(output_dir, idr_properties) |
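        # Plot true vs. predicted values (with R^2) for the best model of each type on each property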
|
for idr_property in idr_properties: |
|
|
|
best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv") |
|
model_type_to_path_dict = dict(zip(best_results['model_type'],best_results['path_to_model'])) |
|
for model_type, path_to_model in model_type_to_path_dict.items(): |
|
model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0] |
|
test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv") |
|
|
|
|
|
if not os.path.exists(f"{output_dir}/r2_plots"): |
|
os.makedirs(f"{output_dir}/r2_plots") |
|
os.makedirs(f"{output_dir}/r2_plots/{idr_property}", exist_ok=True) |
|
|
|
                # Map checkpoint folder names to the display names used by plot_r2
                model_type_dict = {
|
'fuson_plm': 'FusOn-pLM', |
|
'esm2_t33_650M_UR50D': 'ESM-2' |
|
} |
|
r2_save_path = f"{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png" |
|
plot_r2(model_type_dict[model_type], idr_property, test_preds, r2_save_path) |
|
|
|
if __name__ == "__main__": |
|
main() |