|
import torch |
|
import time |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
import os |
|
|
|
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor |
|
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
|
import fuson_plm.benchmarking.puncta.config as config |
|
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts |
|
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams |
|
|
|
def check_splits(df):
    """Validate the train/test split table used by the puncta benchmark.

    Args:
        df: DataFrame with at least a ``split`` column and an ``aa_seq`` column.

    Raises:
        ValueError: (a subclass of ``Exception``) if any row has no split
            assigned, if the set of split labels is not exactly
            ``{'train', 'test'}``, or if any ``aa_seq`` value appears more
            than once.
    """
    if df['split'].isna().any():
        raise ValueError("Error: not every benchmarking sequence has been allocated to a split (train or test)")

    # Compare for equality rather than only checking that both labels are
    # present, so unexpected labels (e.g. 'val') are rejected as the message
    # promises. NaNs were already ruled out above, so they cannot appear here.
    if set(df['split'].unique()) != {'train', 'test'}:
        raise ValueError("Error: splits column should only have 'train' and 'test'.")

    if df['aa_seq'].duplicated().any():
        raise ValueError("Error: duplicate sequences provided")
|
|
|
def _prepend_model_details(stats_df, details):
    """Return a copy of ``stats_df`` with 'Model Type', 'Model Name', 'Model Epoch' as leading columns.

    Values come from ``details['model_type']``, ``details['model']`` and
    ``details['epoch']`` and are repeated for every row of ``stats_df``.
    """
    original_cols = list(stats_df.columns)
    stats_df = stats_df.copy()
    n_rows = len(stats_df)
    # Repeat per row so this works for any number of rows, not just one.
    stats_df['Model Type'] = [details['model_type']] * n_rows
    stats_df['Model Name'] = [details['model']] * n_rows
    stats_df['Model Epoch'] = [details['epoch']] * n_rows
    return stats_df[['Model Type', 'Model Name', 'Model Epoch'] + original_cols]


def _append_csv(df, path):
    """Append ``df`` to the CSV at ``path``; write a header only when the file does not yet exist."""
    file_exists = os.path.exists(path)
    df.to_csv(path, mode='a' if file_exists else 'w', index=False, header=not file_exists)


def train_and_evaluate_puncta_predictor(details, splits_with_embeddings, outdir, task='nucleus', class1_thresh=0.5, n_estimators=50, tree_method="hist"):
    """Train a predictor on the 'train' rows and evaluate it on the 'test' rows for one task.

    Args:
        details: mapping with keys 'model_type', 'model' and 'epoch' identifying
            the embedding model; these are written as the leading result columns.
        splits_with_embeddings: DataFrame with columns 'split' ('train'/'test'),
            'embedding' and one label column per task.
        outdir: directory the results CSV is appended to.
        task: 'nucleus', 'cytoplasm', or 'formation'; also the label column name.
        class1_thresh: threshold forwarded to ``evaluate_predictor`` and, for
            the 'formation' task, embedded in the output filename.
        n_estimators, tree_method: forwarded to ``train_final_predictor``.

    Side effects:
        For 'formation', appends the custom (thresholded) stats to
        ``{outdir}/formation_verificationFOs_{class1_thresh}thresh_results.csv``.
        For any other task, appends the automatic stats to
        ``{outdir}/{task}_verificationFOs_results.csv``.
    """
    is_train = splits_with_embeddings['split'] == 'train'
    is_test = splits_with_embeddings['split'] == 'test'
    train_split = splits_with_embeddings.loc[is_train].reset_index(drop=True)
    test_split = splits_with_embeddings.loc[is_test].reset_index(drop=True)

    # Stack per-sequence embeddings into a 2-D feature matrix; labels come from the task column.
    X_train = np.array(train_split['embedding'].tolist())
    y_train = np.array(train_split[task].tolist())
    X_test = np.array(test_split['embedding'].tolist())
    y_test = np.array(test_split[task].tolist())

    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)

    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)

    # Only one of the two stats tables is persisted, depending on the task.
    if task == "formation":
        custom_stats_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
        _append_csv(_prepend_model_details(custom_stats_df, details), custom_stats_path)
    else:
        automatic_stats_path = f'{outdir}/{task}_verificationFOs_results.csv'
        _append_csv(_prepend_model_details(automatic_stats_df, details), automatic_stats_path)
|
|
|
def main():
    """Run the puncta benchmark end to end.

    Steps:
        1. Create a timestamped output directory under ``results/``.
        2. Log the config.
        3. Embed the sequences in ``splits.csv`` with every model enabled in ``config``.
        4. Validate the splits.
        5. Train and evaluate one predictor per (embedding file, task) pair.
        6. Draw the summary bar charts.

    Everything after directory creation runs inside the benchmark logfile context.
    """
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        print_configpy(config)

        # Set before the embedding step below.
        # NOTE(review): torch is imported at module level; this only takes
        # effect if CUDA has not been initialised yet — confirm.
        os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        # Maps each embedding file path to its model details (model_type / model / epoch).
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path='splits.csv', input_fname='FOdb_puncta_sequences',
            average=True, seq_col='aa_seq',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
            benchmark_prott5=config.BENCHMARK_PROTT5,
            overwrite=config.PERMISSION_TO_OVERWRITE)

        splits = pd.read_csv('splits.csv')
        check_splits(splits)
        n_train = len(splits.loc[splits['split'] == 'train'])
        n_test = len(splits.loc[splits['split'] == 'test'])
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")

        train_params = CustomParams(
            N_ESTIMATORS=50,
            TREE_METHOD="hist",
            CLASS1_THRESHOLDS={
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83,
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')

        log_update("\nTraining models")
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")
            # These pickles are written by the embedding step above; never
            # point this at files from an untrusted source.
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as err:
                # Catch Exception (not a bare except) so Ctrl-C / SystemExit
                # still propagate, and chain the original cause for debugging.
                raise Exception(f"Cannot read embeddings from {embedding_path}") from err

            # One row per sequence, with explicit column names; the left merge
            # attaches each sequence's split and task labels.
            splits_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['aa_seq', 'embedding'])
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq', how='left')

            for task in ['nucleus', 'cytoplasm', 'formation']:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(
                    details, splits_with_embeddings, output_dir, task=task,
                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                    n_estimators=train_params.N_ESTIMATORS,
                    tree_method=train_params.TREE_METHOD)

        log_update("\nMaking summary figures:\n")
        log_update("\tbar charts...")
        os.makedirs(f"{output_dir}/figures", exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update("\tDone.")
|
|
|
if __name__ == "__main__":
    # Only run the benchmark when executed as a script, not when imported.
    main()