|
|
|
import fuson_plm.benchmarking.caid.config as config |
|
import os |
|
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES |
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score |
|
|
|
from sklearn.model_selection import ParameterGrid |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import numpy as np |
|
import sys |
|
from datetime import datetime |
|
import logging |
|
|
|
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
|
from fuson_plm.benchmarking.caid.model import DisorderPredictor |
|
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders |
|
from fuson_plm.benchmarking.caid.plot import make_auroc_curve, make_benchmark_auroc_curve |
|
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy |
|
|
|
|
|
logging.getLogger("transformers").setLevel(logging.ERROR) |
|
|
|
def check_env_variables(): |
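    """Logs CUDA_VISIBLE_DEVICES and the GPU devices visible to PyTorch."""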
|
log_update("\nChecking on environment variables...") |
|
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}") |
|
log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}") |
|
for i in range(torch.cuda.device_count()): |
|
log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}") |
|
|
|
def check_splits(df): |
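    """
    Validates the splits DataFrame: every sequence must be assigned to a split,
    the split column must contain exactly 'train' and 'test', and sequences must be unique.
    """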
|
|
|
if len(df.loc[df['split'].isna()])>0: |
|
raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)") |
|
|
|
    if set(df['split'].unique()) != {'train', 'test'}:
        raise Exception("Error: split column must contain exactly the values 'train' and 'test'.")
|
|
|
if len(df.loc[df['Sequence'].duplicated()])>0: |
|
raise Exception("Error: duplicate sequences provided") |
|
|
|
|
|
def train(model, train_loader, optimizer, n_epochs, criterion, device):
    """
    Trains the model for n_epochs epochs.

    Args:
        model (nn.Module): model to be trained
        train_loader (DataLoader): PyTorch DataLoader with training data
        optimizer (torch.optim.Optimizer): optimizer
        n_epochs (int): number of training epochs
        criterion (nn.Module): loss function
        device (torch.device): device (GPU or CPU) used to train the model

    Returns:
        avg_train_losses (list): average training loss for each epoch
    """
|
|
|
model.train() |
|
|
|
|
|
avg_train_losses = [] |
|
|
|
|
|
for epoch in range(1, 1+n_epochs): |
|
log_update(f"EPOCH {epoch}/{n_epochs}") |
|
|
|
|
|
total_train_loss = 0 |
|
|
|
|
|
total_steps = len(train_loader) |
|
update_interval = total_steps // min(20,total_steps) |
|
prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout) |
|
|
|
|
|
|
|
|
|
for batch_idx, (_, embeddings, labels) in enumerate(train_loader, start=1): |
|
|
|
embeddings, labels = embeddings.to(device), labels.to(device) |
|
|
|
|
|
optimizer.zero_grad() |
|
outputs = model(embeddings) |
|
|
|
loss = criterion(outputs, labels) |
|
loss.backward() |
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
total_train_loss += loss.item() |
|
|
|
if batch_idx % update_interval == 0 or batch_idx == total_steps: |
|
prog_bar.update(update_interval) |
|
sys.stdout.flush() |
|
|
|
prog_bar.close() |
|
|
|
|
|
avg_train_loss = total_train_loss / total_steps |
|
avg_train_losses.append(avg_train_loss) |
|
|
|
return avg_train_losses |
|
|
|
|
|
|
|
def evaluate(model, test_loader, device):
    """
    Performs inference with a trained model on the CAID test set.

    Args:
        model (nn.Module): the trained model
        test_loader (DataLoader): PyTorch DataLoader with test data
        device (torch.device): device (GPU or CPU) to be used for inference

    Returns:
        test_sequences (list): sequences in the test set
        test_preds (list): predicted per-residue disorder probabilities
        true_labels (list): ground-truth per-residue disorder labels
    """
|
model.eval() |
|
test_sequences, test_preds, true_labels = [], [], [] |
|
|
|
|
|
total_steps = len(test_loader) |
|
update_interval = total_steps // min(20,total_steps) |
|
prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout) |
|
|
|
with torch.no_grad(): |
|
for batch_idx, (sequences, embeddings, labels) in enumerate(test_loader,start=1): |
|
embeddings, labels = embeddings.to(device), labels.to(device) |
|
|
|
|
|
outputs = model(embeddings) |
|
|
|
assert len(sequences)==1 |
|
test_sequences.append(sequences[0]) |
|
test_preds.append(outputs.cpu().numpy()) |
|
true_labels.append(labels.cpu().numpy()) |
|
|
|
if batch_idx % update_interval == 0 or batch_idx == total_steps: |
|
prog_bar.update(update_interval) |
|
sys.stdout.flush() |
|
prog_bar.close() |
|
return test_sequences, test_preds, true_labels |
|
|
|
|
|
def benchmark(model, bench_loader, device):
    """
    Performs inference with a trained model on the fusion benchmark set.

    Args:
        model (nn.Module): the trained model
        bench_loader (DataLoader): PyTorch DataLoader with benchmarking data
        device (torch.device): device (GPU or CPU) to be used for inference

    Returns:
        bench_sequences (list): sequences in the benchmark set
        bench_preds (list): predicted per-residue disorder probabilities
        true_labels (list): ground-truth per-residue disorder labels
    """
|
model.eval() |
|
bench_sequences, bench_preds, true_labels = [], [], [] |
|
|
|
|
|
total_steps = len(bench_loader) |
|
update_interval = total_steps // min(20,total_steps) |
|
prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout) |
|
|
|
with torch.no_grad(): |
|
for batch_idx, (sequences, embeddings, labels) in enumerate(bench_loader,start=1): |
|
embeddings, labels = embeddings.to(device), labels.to(device) |
|
|
|
|
|
outputs = model(embeddings) |
|
|
|
assert len(sequences)==1 |
|
bench_sequences.append(sequences[0]) |
|
bench_preds.append(outputs.cpu().numpy()) |
|
true_labels.append(labels.cpu().numpy()) |
|
|
|
if batch_idx % update_interval == 0 or batch_idx == total_steps: |
|
prog_bar.update(update_interval) |
|
sys.stdout.flush() |
|
prog_bar.close() |
|
return bench_sequences, bench_preds, true_labels |
|
|
|
def grid_search_caid_predictor(embedding_path, details, output_dir, param_grid, overwrite_saved_model=True): |
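    """
    Trains and evaluates a CAID disorder predictor for every hyperparameter combination in param_grid,
    using the embeddings at embedding_path.
    """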
|
|
|
grid = ParameterGrid(param_grid) |
|
|
|
|
|
training_hyperparams = { |
|
"learning_rate": None, |
|
"num_epochs": None, |
|
"num_layers": None, |
|
"num_heads": None, |
|
"dropout": None |
|
} |
|
|
|
for params in grid: |
|
|
|
training_hyperparams.update(params) |
|
log_update(f"\nHyperparams:{training_hyperparams}") |
|
train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=overwrite_saved_model) |
|
|
|
|
|
def find_best_hyperparams(output_dir, param_grid): |
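    """
    Reads the combined hyperparameter-screen CSVs (test metrics, train losses, fusion benchmark metrics),
    keeps the best-performing hyperparameter setting for each benchmarked embedding model, merges the three
    tables, and writes best_caid_model_results.csv.
    """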
|
|
|
param_cols = [f"caid_model_{k}" for k in param_grid.keys()] |
|
|
|
|
|
test_metrics = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_test_metrics.csv') |
|
train_losses = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_train_losses.csv') |
|
bench_metrics = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv') |
|
|
|
|
|
test_metrics['Model Epoch'] = test_metrics['Model Epoch'].fillna('') |
|
train_losses['Model Epoch'] = train_losses['Model Epoch'].fillna('') |
|
bench_metrics['Model Epoch'] = bench_metrics['Model Epoch'].fillna('') |
|
|
|
|
|
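    # For each benchmarked embedding model, keep the row with the best test metrics
    # (sorted by AUROC, then F1 Score, Accuracy, Precision, Recall)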
benchmarked_model_key = ['Model Type','Model Name','Model Epoch'] |
|
ordered_priority_stats = ['AUROC','F1 Score','Accuracy','Precision','Recall'] |
|
sort_order = benchmarked_model_key + ordered_priority_stats |
|
sort_bools = [True]*len(benchmarked_model_key) + [False]*len(ordered_priority_stats) |
|
test_metrics = test_metrics.sort_values( |
|
sort_order, |
|
ascending=sort_bools |
|
).groupby(benchmarked_model_key).head(1).reset_index(drop=True) |
|
|
|
|
|
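    # For each hyperparameter setting, keep the training loss from the final CAID-model epoch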
group_order = benchmarked_model_key+param_cols |
|
sort_order = group_order+["caid_model_epoch"] |
|
sort_bools = [True]*(len(group_order))+[False]*1 |
|
train_losses = train_losses.sort_values( |
|
by=sort_order, |
|
ascending=sort_bools, |
|
).groupby(group_order).head(1).reset_index(drop=True) |
|
|
|
|
|
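    # Merge the test metrics with the final train losses, then with the (renamed) fusion benchmark metrics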
merge_cols = benchmarked_model_key+param_cols+['path_to_model'] |
|
combined_results = pd.merge( |
|
test_metrics,train_losses, |
|
on=merge_cols, |
|
how='left' |
|
) |
|
|
|
bench_metrics = bench_metrics.rename(columns = {'AUROC': 'Fusion AUROC', |
|
'F1 Score': 'Fusion F1 Score', |
|
'Accuracy': 'Fusion Accuracy', |
|
'Precision': 'Fusion Precision', |
|
'Recall': 'Fusion Recall'}) |
|
combined_results = pd.merge( |
|
combined_results,bench_metrics, |
|
on=merge_cols, |
|
how='left' |
|
) |
|
|
|
|
|
combined_results = combined_results[[ |
|
'Model Type','Model Name','Model Epoch', |
|
'Accuracy','Precision','Recall','F1 Score','AUROC', |
|
'Fusion Accuracy','Fusion Precision','Fusion Recall','Fusion F1 Score','Fusion AUROC', |
|
'caid_model_learning_rate','caid_model_num_epochs','caid_model_num_layers','caid_model_num_heads','caid_model_dropout','caid_model_epoch','caid_model_loss','path_to_model' |
|
]] |
|
combined_results.to_csv(f"{output_dir}/best_caid_model_results.csv",index=False) |
|
|
|
def get_fresh_model(training_hyperparams, device): |
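    """Instantiates a DisorderPredictor with 1280-dimensional input and hidden sizes and the given hyperparameters, then moves it to device."""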
|
input_dim, hidden_dim = 1280, 1280 |
|
model = DisorderPredictor( |
|
input_dim=input_dim, |
|
hidden_dim=hidden_dim, |
|
num_layers=training_hyperparams["num_layers"], |
|
num_heads=training_hyperparams["num_heads"], |
|
dropout=training_hyperparams['dropout'] |
|
) |
|
model.to(device) |
|
|
|
return model |
|
|
|
def predict_from_best_thresh(prob_and_label_df, seq_label_dict=None):
    """
    Finds the best disorder prediction threshold by maximizing F1 Score, then uses it to make per-residue predictions.

    Args:
        prob_and_label_df (pd.DataFrame): DataFrame with columns: sequence, prob_1, where prob_1 is a comma-separated
            string of per-residue disorder probabilities (e.g. "0.012,0.871,0.334")
        seq_label_dict (dict): dictionary mapping sequences to true label strings, e.g. 'MKLP': '1100'

    Returns:
        prob_and_label_df (pd.DataFrame): the original DataFrame with added columns: labels, threshold, pred_labels
    """
|
|
|
prob_and_label_df['labels'] = prob_and_label_df['sequence'].map(seq_label_dict) |
|
|
|
assert prob_and_label_df['labels'].notna().all() |
|
|
|
probs = ','.join(prob_and_label_df['prob_1'].tolist()) |
|
probs = [float(x) for x in probs.split(",")] |
|
true_labels = ''.join(prob_and_label_df['labels'].tolist()) |
|
true_labels = [int(x) for x in list(true_labels)] |
|
total_aas = sum(prob_and_label_df['sequence'].str.len()) |
|
log_update(f"\tLength of dataframe (number of seqs in dataset): {len(prob_and_label_df)}") |
|
log_update(f"\tTotal AAs in dataset: {total_aas}\ttotal probabilities: {len(probs)}\ttotal labels: {len(true_labels)}") |
|
|
|
y_true = np.array(true_labels) |
|
y_probs = np.array(probs) |
|
|
|
|
|
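    # precision_recall_curve returns len(thresholds)+1 precision/recall values; drop the last point so they align with thresholds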
precision, recall, thresholds = precision_recall_curve(y_true, y_probs) |
|
precision = precision[:-1] |
|
recall = recall[:-1] |
|
|
|
    # Guard against division by zero when precision + recall == 0 for some threshold
    denom = precision + recall
    f1_scores = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)
|
|
|
|
|
best_threshold_index = np.argmax(f1_scores) |
|
best_threshold = thresholds[best_threshold_index] |
|
|
|
|
|
auprc = average_precision_score(y_true, y_probs) |
|
|
|
log_update(f"\tBest Threshold: {best_threshold}") |
|
log_update(f"\tBest F1 Score: {f1_scores[best_threshold_index]:.2f}") |
|
log_update(f"\tAUPRC: {auprc:.2f}") |
|
|
|
|
|
|
|
prob_and_label_df['threshold'] = [best_threshold]*len(prob_and_label_df) |
|
|
|
prob_and_label_df['pred_labels'] = prob_and_label_df['prob_1'].apply(lambda x: ['1' if float(y)>best_threshold else '0' for y in x.split(",")]) |
|
prob_and_label_df['pred_labels'] = prob_and_label_df['pred_labels'].apply(lambda x: ''.join(x)) |
|
log_update("\tUsed calculated threshold to construct predicted labels for dataset") |
|
return prob_and_label_df |
|
|
|
|
|
def train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=True): |
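    """
    Trains a disorder predictor on the given embeddings with one hyperparameter setting, evaluates it on the
    CAID test set and the fusion benchmark set, and writes per-model and combined results CSVs. If a model with
    the same hyperparameters already exists and overwrite_saved_model is False, its previously saved results are
    appended to the combined CSVs instead of retraining.
    """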
|
|
|
benchmark_model_type = details['model_type'] |
|
benchmark_model_name = details['model'] |
|
benchmark_model_epoch = details['epoch'] |
|
|
|
|
|
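    # Build the output folder for this combination of embedding model and predictor hyperparameters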
model_outer_folder = f"trained_models/{benchmark_model_type}" |
|
if not(np.isnan(benchmark_model_epoch)): model_outer_folder+=f"/{benchmark_model_name}/epoch{benchmark_model_epoch}" |
|
model_full_folder=f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{1}_hd{1280}_epochs{training_hyperparams['num_epochs']}_layers{training_hyperparams['num_layers']}_heads{training_hyperparams['num_heads']}_drpt{training_hyperparams['dropout']}" |
|
    # os.makedirs with exist_ok=True creates any missing intermediate directories
    os.makedirs(model_full_folder, exist_ok=True)
|
|
|
|
|
model_full_path = f"{model_full_folder}/model.pth" |
|
train_new_model=True |
|
if os.path.exists(model_full_path): |
|
|
|
if overwrite_saved_model: |
|
            log_update(f"\nOverwriting previously trained model with the same hyperparameters at {model_full_path}")
|
|
|
else: |
|
            log_update(f"\nWARNING: a trained model with these hyperparameters already exists at {model_full_path}. Skipping training.")
|
train_new_model=False |
|
|
|
|
|
if train_new_model: |
|
max_length=4500+2 |
|
|
|
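        # Per-sequence dataloaders (batch_size=1) over the precomputed embeddings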
train_dataloader = get_dataloader('splits/train_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=True) |
|
test_dataloader = get_dataloader('splits/test_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False) |
|
benchmark_dataloader = get_dataloader('splits/fusion_bench_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
model = get_fresh_model(training_hyperparams, device) |
|
|
|
|
|
optimizer = optim.Adam(model.parameters(), lr=training_hyperparams["learning_rate"]) |
|
criterion = nn.BCELoss() |
|
num_epochs = training_hyperparams['num_epochs'] |
|
|
|
|
|
|
|
avg_train_losses = train(model, train_dataloader, optimizer, num_epochs, criterion, device) |
|
|
|
formatted_hyperparams = {f"caid_model_{k}":v for k, v in training_hyperparams.items()} |
|
train_loss_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T |
|
train_loss_df['caid_model_epoch'] = [list(range(1,1+num_epochs))] |
|
train_loss_df['caid_model_loss'] = [avg_train_losses] |
|
train_loss_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]] |
|
train_loss_df = train_loss_df.explode(['caid_model_epoch', 'caid_model_loss']) |
|
|
|
|
|
train_loss_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_train_losses.csv' |
|
train_loss_individual_results_csv_path = f'{model_full_folder}/caid_train_losses.csv' |
|
train_loss_df.to_csv(train_loss_individual_results_csv_path,mode='w',index=False) |
|
train_loss_df['path_to_model'] = model_full_path |
|
if not(os.path.exists(train_loss_combined_results_csv_path)): |
|
train_loss_df.to_csv(train_loss_combined_results_csv_path,index=False) |
|
else: |
|
train_loss_df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
log_update(f"Final train loss: {avg_train_losses[-1]:.4f}") |
|
|
|
|
|
|
|
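        # Evaluate on the CAID test set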
test_sequences, test_preds, test_labels = evaluate(model, test_dataloader, device) |
|
test_metrics = calculate_metrics(test_preds, test_labels) |
|
|
|
test_results_df = pd.DataFrame.from_dict(test_metrics,orient='index').T |
|
test_results_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]] |
|
|
|
hyperparams_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T |
|
test_results_df = pd.concat([test_results_df,hyperparams_df],axis=1) |
|
|
|
|
|
|
|
prob_and_label_df = pd.DataFrame(data = { |
|
'sequence': test_sequences, |
|
'prob_1': [arr.flatten() for arr in test_preds] |
|
}) |
|
prob_and_label_df['prob_1'] = prob_and_label_df['prob_1'].apply( |
|
lambda prob_list: ",".join([f"{round(x, 3):.3f}" for x in prob_list]) |
|
) |
|
prob_and_label_df['Model Type'] = [benchmark_model_type]*len(prob_and_label_df) |
|
prob_and_label_df['Model Name'] = [benchmark_model_name]*len(prob_and_label_df) |
|
prob_and_label_df['Model Epoch'] = [benchmark_model_epoch]*len(prob_and_label_df) |
|
|
|
|
|
test_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_test_metrics.csv' |
|
test_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_metrics.csv' |
|
test_results_df.to_csv(test_results_csv_path,mode='w',index=False) |
|
test_results_df['path_to_model'] = model_full_path |
|
if not(os.path.exists(test_combined_results_csv_path)): |
|
test_results_df.to_csv(test_combined_results_csv_path,index=False) |
|
else: |
|
test_results_df.to_csv(test_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
test_probs_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_probs.csv' |
|
seq_label_dict = pd.read_csv('splits/test_df.csv') |
|
seq_label_dict = dict(zip(seq_label_dict['Sequence'],seq_label_dict['Label'])) |
|
log_update("Finding best threshold for CAID test set predictions based on maximizing F1 Score...") |
|
prob_and_label_df = predict_from_best_thresh(prob_and_label_df, seq_label_dict=seq_label_dict) |
|
prob_and_label_df[['sequence','prob_1','threshold','pred_labels']].to_csv(test_probs_csv_path,mode='w',index=False) |
|
|
|
log_update(f"Test performance: {test_metrics}") |
|
|
|
|
|
|
|
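        # Evaluate on the fusion benchmark set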
benchmark_sequences, benchmark_preds, benchmark_labels = evaluate(model, benchmark_dataloader, device) |
|
benchmark_metrics = calculate_metrics(benchmark_preds, benchmark_labels) |
|
|
|
benchmark_results_df = pd.DataFrame.from_dict(benchmark_metrics,orient='index').T |
|
benchmark_results_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]] |
|
|
|
hyperparams_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T |
|
benchmark_results_df = pd.concat([benchmark_results_df,hyperparams_df],axis=1) |
|
|
|
|
|
|
|
prob_and_label_df = pd.DataFrame(data = { |
|
'sequence': benchmark_sequences, |
|
'prob_1': [arr.flatten() for arr in benchmark_preds] |
|
}) |
|
prob_and_label_df['prob_1'] = prob_and_label_df['prob_1'].apply( |
|
lambda prob_list: ",".join([f"{round(x, 3):.3f}" for x in prob_list]) |
|
) |
|
prob_and_label_df['Model Type'] = [benchmark_model_type]*len(prob_and_label_df) |
|
prob_and_label_df['Model Name'] = [benchmark_model_name]*len(prob_and_label_df) |
|
prob_and_label_df['Model Epoch'] = [benchmark_model_epoch]*len(prob_and_label_df) |
|
|
|
|
|
benchmark_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv' |
|
benchmark_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_metrics.csv' |
|
benchmark_results_df.to_csv(benchmark_results_csv_path,mode='w',index=False) |
|
benchmark_results_df['path_to_model'] = model_full_path |
|
if not(os.path.exists(benchmark_combined_results_csv_path)): |
|
benchmark_results_df.to_csv(benchmark_combined_results_csv_path,index=False) |
|
else: |
|
benchmark_results_df.to_csv(benchmark_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
benchmark_probs_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_probs.csv' |
|
seq_label_dict = pd.read_csv('splits/fusion_bench_df.csv') |
|
seq_label_dict = dict(zip(seq_label_dict['Sequence'],seq_label_dict['Label'])) |
|
log_update("Finding best threshold for fusion benchmark set predictions based on maximizing F1 Score...") |
|
prob_and_label_df = predict_from_best_thresh(prob_and_label_df, seq_label_dict=seq_label_dict) |
|
prob_and_label_df[['sequence','prob_1','threshold','pred_labels']].to_csv(benchmark_probs_csv_path,mode='w',index=False) |
|
|
|
log_update(f"benchmark performance: {benchmark_metrics}") |
|
|
|
|
|
|
|
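        # Save the trained CAID predictor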
torch.save(model.state_dict(), model_full_path) |
|
|
|
|
|
else: |
|
|
|
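        # Model already trained and not being overwritten: append its previously saved results to the combined CSVs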
train_loss_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_train_losses.csv' |
|
train_loss_individual_results_csv_path = f'{model_full_folder}/caid_train_losses.csv' |
|
train_loss_individual_results = pd.read_csv(train_loss_individual_results_csv_path) |
|
train_loss_individual_results['path_to_model'] = [model_full_path]*len(train_loss_individual_results) |
|
|
|
if not(os.path.exists(train_loss_combined_results_csv_path)): |
|
train_loss_individual_results.to_csv(train_loss_combined_results_csv_path,index=False) |
|
else: |
|
train_loss_individual_results.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
test_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_test_metrics.csv' |
|
test_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_metrics.csv' |
|
test_individual_results = pd.read_csv(test_results_csv_path) |
|
test_individual_results['path_to_model'] = [model_full_path]*len(test_individual_results) |
|
|
|
if not(os.path.exists(test_combined_results_csv_path)): |
|
test_individual_results.to_csv(test_combined_results_csv_path,index=False) |
|
else: |
|
test_individual_results.to_csv(test_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
benchmark_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv' |
|
benchmark_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_metrics.csv' |
|
benchmark_individual_results = pd.read_csv(benchmark_results_csv_path) |
|
benchmark_individual_results['path_to_model'] = [model_full_path]*len(benchmark_individual_results) |
|
|
|
if not(os.path.exists(benchmark_combined_results_csv_path)): |
|
benchmark_individual_results.to_csv(benchmark_combined_results_csv_path,index=False) |
|
else: |
|
benchmark_individual_results.to_csv(benchmark_combined_results_csv_path,mode='a',index=False,header=False) |
|
|
|
|
|
def calculate_metrics(preds, labels, threshold=0.5):
    """
    Calculates metrics to assess model performance.

    Args:
        preds (list): model's predictions (per-residue probabilities)
        labels (list): ground-truth labels
        threshold (float): minimum probability a residue must reach to be predicted as disordered

    Returns:
        metrics_dict (dict): Accuracy, Precision, Recall, F1 Score, and AUROC
    """
|
flat_binary_preds, flat_prob_preds, flat_labels = [], [], [] |
|
|
|
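    # Flatten per-sequence predictions and labels into single arrays of residues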
for pred, label in zip(preds, labels): |
|
flat_binary_preds.extend((pred > threshold).astype(int).flatten()) |
|
flat_prob_preds.extend(pred.flatten()) |
|
flat_labels.extend(label.flatten()) |
|
|
|
flat_binary_preds = np.array(flat_binary_preds) |
|
flat_prob_preds = np.array(flat_prob_preds) |
|
flat_labels = np.array(flat_labels) |
|
|
|
accuracy = accuracy_score(flat_labels, flat_binary_preds) |
|
precision = precision_score(flat_labels, flat_binary_preds) |
|
recall = recall_score(flat_labels, flat_binary_preds) |
|
f1 = f1_score(flat_labels, flat_binary_preds) |
|
roc_auc = roc_auc_score(flat_labels, flat_prob_preds) |
|
|
|
|
|
metrics_dict = { |
|
'Accuracy': accuracy, |
|
'Precision': precision, |
|
'Recall': recall, |
|
'F1 Score': f1, |
|
'AUROC': roc_auc |
|
} |
|
|
|
return metrics_dict |
|
|
|
def main(): |
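    """Embeds the CAID sequences with each benchmarked model, runs the hyperparameter grid search, and summarizes the best results."""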
|
|
|
os.makedirs('results',exist_ok=True) |
|
output_dir = f'results/{get_local_time()}' |
|
os.makedirs(output_dir,exist_ok=True) |
|
|
|
with open_logfile(f'{output_dir}/caid_benchmark_log.txt'): |
|
|
|
print_configpy(config) |
|
|
|
|
|
check_env_variables() |
|
|
|
|
|
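        # Compute (or reuse, depending on the overwrite flag) per-residue embeddings for each benchmarked model checkpoint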
all_embedding_paths = embed_dataset_for_benchmark( |
|
fuson_ckpts=config.FUSONPLM_CKPTS, |
|
input_data_path='splits/splits.csv', |
|
input_fname='CAID2_competition_sequences', |
|
average=False, seq_col='Sequence', |
|
benchmark_fusonplm=config.BENCHMARK_FUSONPLM, |
|
benchmark_esm=config.BENCHMARK_ESM, |
|
benchmark_fo_puncta_ml=False, |
|
overwrite=config.PERMISSION_TO_OVERWRITE_EMBEDDINGS) |
|
|
|
|
|
splits_df = pd.read_csv('splits/splits.csv') |
|
log_update(f"\nSplit breakdown...\n\t{len(splits_df.loc[splits_df['Split']=='Train'])} train seqs\n\t{len(splits_df.loc[splits_df['Split']=='Test'])} test seqs") |
|
|
|
log_update("\nTraining and evaluating models") |
|
|
|
|
|
param_grid = { |
|
'learning_rate': [5e-5], |
|
'num_heads': [5, 8, 10], |
|
'num_layers': [2, 4, 6], |
|
'dropout': [0.2, 0.5], |
|
'num_epochs': [2] |
|
} |
|
|
|
|
|
for embedding_path, details in all_embedding_paths.items(): |
|
log_update(f"\nBenchmarking embeddings at: {embedding_path}") |
|
|
|
grid_search_caid_predictor(embedding_path, details, output_dir, param_grid, overwrite_saved_model=config.PERMISSION_TO_OVERWRITE_MODELS) |
|
|
|
|
|
find_best_hyperparams(output_dir, param_grid) |
|
|
|
|
|
|
|
best_caid_model_results = pd.read_csv(f"{output_dir}/best_caid_model_results.csv") |
|
|
|
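        # Benchmark-focused view of the best-model results: fusion metrics renamed to the standard metric names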
best_caid_model_results_benchmark = best_caid_model_results.drop(columns= |
|
['AUROC','F1 Score','Accuracy','Precision','Recall'] |
|
).rename(columns={ |
|
'Fusion AUROC': 'AUROC', |
|
'Fusion F1 Score': 'F1 Score', |
|
'Fusion Accuracy': 'Accuracy', |
|
'Fusion Precision': 'Precision', |
|
'Fusion Recall': 'Recall' |
|
}) |
|
|
|
if __name__ == "__main__": |
|
main() |