"""Benchmark XGBoost puncta predictors on embeddings of the FOdb puncta sequences.

For every embedding set produced by ``embed_dataset_for_benchmark``, one
classifier is trained per task ('nucleus', 'cytoplasm', 'formation') on the
'train' split of ``splits.csv`` and evaluated on the 'test' split; results are
appended to CSVs under a timestamped ``results/`` directory.
"""
import torch
import time
import pandas as pd
import numpy as np
import pickle
import os

from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
import fuson_plm.benchmarking.puncta.config as config
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams


def check_splits(df):
    """Validate the benchmark split table, raising on the first problem found.

    Checks that every row has a split, that the 'split' column contains exactly
    the values 'train' and 'test' (both present, nothing else), and that no
    amino-acid sequence appears more than once.

    Args:
        df: DataFrame with at least 'split' and 'aa_seq' columns.

    Raises:
        Exception: if any of the checks fails.
    """
    # make sure everything has a split
    if df['split'].isna().any():
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")

    # Make sure the only splits are train and test. A one-sided difference
    # ({'train','test'} - present) only catches a *missing* split; comparing the
    # sets for equality also rejects unexpected labels such as 'val'.
    if set(df['split'].unique()) != {'train', 'test'}:
        raise Exception("Error: splits column should only have 'train' and 'test'.")

    # make sure there are no duplicate sequences
    if df['aa_seq'].duplicated().any():
        raise Exception("Error: duplicate sequences provided")


def _prepend_model_details(stats_df, model_details):
    """Return a copy of stats_df with each (column, value) pair inserted, in order, as the leading columns.

    Scalar values are broadcast to every row, so this works for stats tables of
    any length (assigning a one-element list would only work for exactly one row).
    """
    stats_df = stats_df.copy()  # don't mutate the frame returned by evaluate_predictor
    for position, (column, value) in enumerate(model_details):
        stats_df.insert(position, column, value)
    return stats_df


def _append_csv(df, path):
    """Write df to path with a header, or append it without one if the file already exists."""
    if os.path.exists(path):
        df.to_csv(path, mode='a', index=False, header=False)
    else:
        df.to_csv(path, index=False)


def train_and_evaluate_puncta_predictor(details, splits_with_embeddings, outdir, task='nucleus', class1_thresh=0.5, n_estimators=50, tree_method="hist"):
    """Train a classifier for one puncta task and append its test-set stats to a CSV.

    Args:
        details: dict describing the embedding model; must contain the keys
            'model_type', 'model', and 'epoch'.
        splits_with_embeddings: DataFrame with a 'split' column ('train'/'test'),
            an 'embedding' column, and a label column named ``task``.
        outdir: directory the results CSV is written into.
        task: 'nucleus', 'cytoplasm', or 'formation' -- the label column to predict.
        class1_thresh: threshold passed to ``evaluate_predictor``; also part of
            the 'formation' results filename.
        n_estimators: passed to ``train_final_predictor``.
        tree_method: passed to ``train_final_predictor``.

    Side effects:
        For 'nucleus'/'cytoplasm', appends the first stats table returned by
        ``evaluate_predictor`` to ``{outdir}/{task}_verificationFOs_results.csv``.
        For 'formation', appends the second (custom-threshold) table to
        ``{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv``.
    """
    # Look up the model details before training so a malformed `details` fails fast.
    model_details = [
        ('Model Type', details['model_type']),
        ('Model Name', details['model']),
        ('Model Epoch', details['epoch']),
    ]

    # prepare train and test sets for model
    is_train = splits_with_embeddings['split'] == 'train'
    is_test = splits_with_embeddings['split'] == 'test'
    train_split = splits_with_embeddings.loc[is_train].reset_index(drop=True)
    test_split = splits_with_embeddings.loc[is_test].reset_index(drop=True)

    X_train = np.array(train_split['embedding'].tolist())
    y_train = np.array(train_split[task].tolist())
    X_test = np.array(test_split['embedding'].tolist())
    y_test = np.array(test_split[task].tolist())

    # Train the final model on all the training data, then evaluate on the test split
    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)
    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)

    if task == "formation":
        # Save custom threshold results (only for formation)
        custom_stats_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
        _append_csv(_prepend_model_details(custom_stats_df, model_details), custom_stats_path)
    else:
        # Save automatic results (for nucleus and cytoplasm)
        automatic_stats_path = f'{outdir}/{task}_verificationFOs_results.csv'
        _append_csv(_prepend_model_details(automatic_stats_df, model_details), automatic_stats_path)


def main():
    """Run the full puncta benchmark: embed, validate splits, train/evaluate per task, and plot.

    Creates ``results/<local time>/`` holding the run log, per-task results
    CSVs, and a ``figures/`` subdirectory for the summary bar charts.
    """
    # make output directory for this run
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        # print configurations
        print_configpy(config)

        # Set the visible GPUs from config, then echo the value actually in effect
        os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        # make embeddings if needed; returns {embedding_path: model details dict}
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path='splits.csv',
            input_fname='FOdb_puncta_sequences',
            average=True,
            seq_col='aa_seq',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
            benchmark_prott5=config.BENCHMARK_PROTT5,
            overwrite=config.PERMISSION_TO_OVERWRITE,
        )

        # load the splits with labels and sanity-check them
        splits = pd.read_csv('splits.csv')
        check_splits(splits)
        n_train = len(splits.loc[splits['split'] == 'train'])
        n_test = len(splits.loc[splits['split'] == 'test'])
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")

        # set training constants
        train_params = CustomParams(
            N_ESTIMATORS=50,
            TREE_METHOD="hist",
            CLASS1_THRESHOLDS={
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83,
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')

        log_update("\nTraining models")
        # loop through the embedding paths and train on each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")
            # NOTE: pickle.load can execute arbitrary code. These paths come from
            # embed_dataset_for_benchmark above; never point this at untrusted files.
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as e:
                # Narrower than a bare except (lets KeyboardInterrupt through) and keeps the cause.
                raise Exception(f"Cannot read embeddings from {embedding_path}") from e

            # combine the embeddings ({aa_seq: embedding}) and splits into one dataframe
            splits_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['aa_seq', 'embedding'])
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq', how='left')

            for task in ['nucleus', 'cytoplasm', 'formation']:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(
                    details,
                    splits_with_embeddings,
                    output_dir,
                    task=task,
                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                    n_estimators=train_params.N_ESTIMATORS,
                    tree_method=train_params.TREE_METHOD,
                )

        log_update("\nMaking summary figures:\n")
        log_update("\tbar charts...")
        os.makedirs(f"{output_dir}/figures", exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update("\tDone.")


if __name__ == '__main__':
    main()