# svincoff's picture
# dependencies and embedding_exploration benchmark
# c43fbc6
import torch
import time
import pandas as pd
import numpy as np
import pickle
import os
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
import fuson_plm.benchmarking.puncta.config as config
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams
def check_splits(df):
    """
    Sanity-check the benchmarking split table before any training happens.

    Args:
        df: pandas DataFrame with an 'aa_seq' column (one sequence per row) and a
            'split' column assigning each row to a split.

    Raises:
        Exception: if any row has a missing split, if the 'split' column contains a
            value other than 'train' or 'test', if either 'train' or 'test' is
            absent from the column, or if any value of 'aa_seq' appears more than once.
    """
    allowed_splits = {'train', 'test'}

    # make sure everything has a split
    if df['split'].isna().any():
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")

    observed_splits = set(df['split'].unique())

    # make sure the only things are train and test
    # (observed - allowed catches stray labels; the reverse difference would only
    # catch a missing label and let e.g. 'val' through unnoticed)
    if observed_splits - allowed_splits:
        raise Exception("Error: splits column should only have 'train' and 'test'.")

    # both splits are needed downstream, so a missing one is still an error
    missing_splits = allowed_splits - observed_splits
    if missing_splits:
        raise Exception(f"Error: splits column is missing required split(s): {sorted(missing_splits)}")

    # make sure there are no duplicate sequences
    if df['aa_seq'].duplicated().any():
        raise Exception("Error: duplicate sequences provided")
def train_and_evaluate_puncta_predictor(details, splits_with_embeddings, outdir, task='nucleus', class1_thresh=0.5, n_estimators=50, tree_method="hist"):
    """
    Fit a predictor on the 'train' rows, score it on the 'test' rows, and append the
    resulting statistics to a CSV file under `outdir`.

    Args:
        details: dict with keys 'model_type', 'model' and 'epoch' identifying the
            embedding model being benchmarked.
        splits_with_embeddings: DataFrame with 'split', 'embedding' and `task` columns.
        outdir: directory the results CSV is written to.
        task: label column to predict: 'nucleus', 'cytoplasm', or 'formation'.
        class1_thresh: passed to evaluate_predictor; also embedded in the file name
            of the 'formation' results.
        n_estimators: passed to train_final_predictor.
        tree_method: passed to train_final_predictor.

    Returns:
        None. Results are written to disk only.
    """
    # identifying columns that get placed in front of the metric columns
    id_columns = ['Model Type', 'Model Name', 'Model Epoch']
    id_values = [details['model_type'], details['model'], details['epoch']]

    def _features_and_labels(split_name):
        # rows of one split -> (embedding matrix, label vector)
        in_split = splits_with_embeddings['split'] == split_name
        subset = splits_with_embeddings.loc[in_split].reset_index(drop=True)
        features = np.array(subset['embedding'].tolist())
        labels = np.array(subset[task].tolist())
        return features, labels

    def _with_model_details(stats_df):
        # add the identifying columns, then move them ahead of the metrics
        # (assigning a one-element list presumes a single-row stats frame)
        metric_columns = list(stats_df.columns)
        for column, value in zip(id_columns, id_values):
            stats_df[column] = [value]
        return stats_df[id_columns + metric_columns]

    X_train, y_train = _features_and_labels('train')
    X_test, y_test = _features_and_labels('test')

    # fit on the training split, score on the held-out split
    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)
    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)

    automatic_stats_df = _with_model_details(automatic_stats_df)
    custom_stats_df = _with_model_details(custom_stats_df)

    # 'formation' keeps the custom-threshold results; the other tasks keep the automatic ones
    if task == "formation":
        results_df = custom_stats_df
        results_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
    else:
        results_df = automatic_stats_df
        results_path = f'{outdir}/{task}_verificationFOs_results.csv'

    # first write creates the file with a header; later writes append rows only
    if os.path.exists(results_path):
        results_df.to_csv(results_path, mode='a', index=False, header=False)
    else:
        results_df.to_csv(results_path, index=False)
def main():
    """
    Run the full puncta benchmark.

    Creates a timestamped directory under 'results', opens a log file in it, builds
    (or reuses) embeddings for the sequences in 'splits.csv', validates the splits,
    then trains and evaluates one predictor per embedding file and per task
    ('nucleus', 'cytoplasm', 'formation'), and finally produces summary bar charts.

    Raises:
        Exception: if the splits fail validation or an embedding file cannot be read.
    """
    splits_path = 'splits.csv'
    tasks = ['nucleus', 'cytoplasm', 'formation']

    # make output directory for this run
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        # print configurations
        print_configpy(config)

        # Set the visible GPUs from the config and log the value in effect
        os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path=splits_path, input_fname='FOdb_puncta_sequences',
            average=True, seq_col='aa_seq',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
            benchmark_prott5=config.BENCHMARK_PROTT5,
            overwrite=config.PERMISSION_TO_OVERWRITE)

        # load the splits with labels
        splits = pd.read_csv(splits_path)

        # perform some sanity checks on the splits
        check_splits(splits)
        n_train = len(splits.loc[splits['split'] == 'train'])
        n_test = len(splits.loc[splits['split'] == 'test'])
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")

        # set training constants
        train_params = CustomParams(
            N_ESTIMATORS=50,
            TREE_METHOD="hist",
            CLASS1_THRESHOLDS={
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')

        log_update("\nTraining models")
        # loop through the embedding paths and train each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")
            # NOTE(review): pickle.load executes arbitrary code from the file; only
            # point this at embedding files produced by this pipeline.
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as err:
                # not a bare except: Ctrl-C / SystemExit must still propagate,
                # and the underlying error is kept as the cause
                raise Exception(f"Cannot read embeddings from {embedding_path}") from err

            # combine the embeddings and splits into one dataframe
            splits_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
            splits_with_embeddings = splits_with_embeddings.rename(columns={0: 'aa_seq', 1: 'embedding'})
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq', how='left')

            for task in tasks:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(
                    details, splits_with_embeddings, output_dir, task=task,
                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                    n_estimators=train_params.N_ESTIMATORS,
                    tree_method=train_params.TREE_METHOD)

        # summary figures are built once, from everything written to output_dir
        log_update("\nMaking summary figures:\n")
        log_update("\tbar charts...")
        os.makedirs(f"{output_dir}/figures", exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update("\tDone.")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()