# first few imports, just to set CUDA_VISIBLE_DEVICES before importing any torch libraries
import fuson_plm.benchmarking.caid.config as config
import os
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

# remaining imports
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score

from sklearn.model_selection import ParameterGrid
from tqdm import tqdm
import pandas as pd
import numpy as np
import sys
from datetime import datetime
import logging

from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
from fuson_plm.benchmarking.caid.model import DisorderPredictor
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
from fuson_plm.benchmarking.caid.plot import make_auroc_curve, make_benchmark_auroc_curve
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy

# configure Transformers logger to only show messages that are ERROR or more severe 
logging.getLogger("transformers").setLevel(logging.ERROR)

def check_env_variables():
    log_update("\nChecking on environment variables...")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
        
def check_splits(df):
    # make sure everything has a split
    if len(df.loc[df['split'].isna()])>0:
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")
    # make sure the only things are train and test
    if len({'train','test'} - set(df['split'].unique()))!=0:
        raise Exception("Error: splits column should only have \'train\' and \'test\'.")
    # make sure there are no duplicate sequences
    if len(df.loc[df['Sequence'].duplicated()])>0:
        raise Exception("Error: duplicate sequences provided")

# Training function
def train(model, train_loader, optimizer, n_epochs, criterion, device):
    """
    Trains the model for n_epochs epochs.
    Args:
        model (nn.Module): model that will be trained
        train_loader (DataLoader): PyTorch DataLoader with training data
        optimizer (torch.optim): optimizer
        n_epochs (int): number of training epochs
        criterion (nn.Module): loss function
        device (torch.device): device (GPU or CPU) used to train the model
    Returns:
        avg_train_losses (list): average training loss for each epoch
    """
    # Training loop
    model.train()
    
    # Avg loss across epochs
    avg_train_losses = []
    
    # Loop through epochs
    for epoch in range(1, 1+n_epochs):
        log_update(f"EPOCH {epoch}/{n_epochs}")
        
        # Initialize loss for the epoch to 0
        total_train_loss = 0
        
        # Make update settings
        total_steps = len(train_loader)
        update_interval = total_steps // min(20,total_steps) # update semi-frequently
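        # roughly 20 progress-bar updates per epoch; every batch when there are fewer than 20 batches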
        prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout)
        
        # Iterate through batches
        for batch_idx, (_, embeddings, labels) in enumerate(train_loader, start=1):
            # Move tensors to device
            embeddings, labels = embeddings.to(device), labels.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(embeddings)

            loss = criterion(outputs, labels)
            loss.backward()
            
            # Parameter updates
            optimizer.step()
            
            # Update loss
            total_train_loss += loss.item()

            if batch_idx % update_interval == 0 or batch_idx == total_steps:
                prog_bar.update(update_interval)
                sys.stdout.flush()
                
        prog_bar.close()
    
        # Calculate avg loss for the epoch
        avg_train_loss = total_train_loss / total_steps
        avg_train_losses.append(avg_train_loss)
        
    return avg_train_losses


# Evaluation function
def evaluate(model, test_loader, device):
    """
    Performs inference on a trained model
    Args:
        model (nn.Module): the trained model
        test_loader (DataLoader): PyTorch DataLoader with testing data
        device (torch.device): device (GPU or CPU) to be used for inference
    Returns:
        test_sequences (list): sequences in the order they were evaluated
        test_preds (list): predicted per-residue disorder probabilities
        true_labels (list): ground truth per-residue disorder labels
    """
    model.eval()
    test_sequences, test_preds, true_labels = [], [], []
    
    # Make update settings
    total_steps = len(test_loader)
    update_interval = total_steps // min(20,total_steps) # update semi-frequently
    prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout)
    
    with torch.no_grad():
        for batch_idx, (sequences, embeddings, labels) in enumerate(test_loader,start=1):
            embeddings, labels = embeddings.to(device), labels.to(device)
            
            # forward pass
            outputs = model(embeddings)
            
            assert len(sequences)==1    # the batch size should be 1; make sure
            test_sequences.append(sequences[0])
            test_preds.append(outputs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())
            
            if batch_idx % update_interval == 0 or batch_idx == total_steps:
                prog_bar.update(update_interval)
                sys.stdout.flush()
        prog_bar.close()
    return test_sequences, test_preds, true_labels

# Benchmark evaluation function
def benchmark(model, bench_loader, device):
    """
    Performs inference on a trained model
    Args:
        model (nn.Module): the trained model
        bench_loader (DataLoader): PyTorch DataLoader with benchmarking data
        device (torch.device): device (GPU or CPU) to be used for inference
    Returns:
        bench_sequences (list): sequences in the order they were evaluated
        bench_preds (list): predicted per-residue disorder probabilities
        true_labels (list): ground truth per-residue disorder labels
    """
    model.eval()
    bench_sequences, bench_preds, true_labels = [], [], []
    
    # Make update settings
    total_steps = len(bench_loader)
    update_interval = total_steps // min(20,total_steps) # update semi-frequently
    prog_bar = tqdm(total=total_steps, leave=True, file=sys.stdout)
    
    with torch.no_grad():
        for batch_idx, (sequences, embeddings, labels) in enumerate(bench_loader,start=1):
            embeddings, labels = embeddings.to(device), labels.to(device)
            
            # forward pass
            outputs = model(embeddings)
            
            assert len(sequences)==1    # the batch size should be 1; make sure
            bench_sequences.append(sequences[0])
            bench_preds.append(outputs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())
            
            if batch_idx % update_interval == 0 or batch_idx == total_steps:
                prog_bar.update(update_interval)
                sys.stdout.flush()
        prog_bar.close()
    return bench_sequences, bench_preds, true_labels

def grid_search_caid_predictor(embedding_path, details, output_dir, param_grid, overwrite_saved_model=True):
    # prepare the grid search
    grid = ParameterGrid(param_grid)
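    # sklearn's ParameterGrid yields every combination of the values in param_grid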
    
    # initialize dict
    training_hyperparams = {
        "learning_rate": None,
        "num_epochs": None,
        "num_layers": None,
        "num_heads": None,
        "dropout": None
    }
    
    for params in grid:
        # Update hyperparameters
        training_hyperparams.update(params)
        log_update(f"\nHyperparams:{training_hyperparams}")
        train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=overwrite_saved_model)
        
        
def find_best_hyperparams(output_dir, param_grid):
    # Isolate the columns that define the hyperparameters
    param_cols = [f"caid_model_{k}" for k in param_grid.keys()]
    
    # Read in the files with all the stats
    test_metrics = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_test_metrics.csv')
    train_losses = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_train_losses.csv')
    bench_metrics = pd.read_csv(f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv')
    
    # Replace nan with empty string for epoch
    test_metrics['Model Epoch'] = test_metrics['Model Epoch'].fillna('')
    train_losses['Model Epoch'] = train_losses['Model Epoch'].fillna('')
    bench_metrics['Model Epoch'] = bench_metrics['Model Epoch'].fillna('')
    
    # Find the hyperparams that produced the best test metrics for each model; then save all relevant numbers in one file
    benchmarked_model_key = ['Model Type','Model Name','Model Epoch']   # uniquely defines the model being benchmarked
    ordered_priority_stats = ['AUROC','F1 Score','Accuracy','Precision','Recall']
    sort_order = benchmarked_model_key + ordered_priority_stats
    sort_bools = [True]*len(benchmarked_model_key) + [False]*len(ordered_priority_stats)
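    # sort the model-key columns ascending and the metric columns descending, then keep the top row per
    # benchmarked model: the hyperparameter combination with the best AUROC (ties broken by the other metrics)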
    test_metrics = test_metrics.sort_values(
            sort_order, 
            ascending=sort_bools
        ).groupby(benchmarked_model_key).head(1).reset_index(drop=True)

    # Find the last-epoch losses for each model and hyperparameters
    group_order = benchmarked_model_key+param_cols
    sort_order = group_order+["caid_model_epoch"]
    sort_bools = [True]*(len(group_order))+[False]*1
    train_losses = train_losses.sort_values(
            by=sort_order,
            ascending=sort_bools,
        ).groupby(group_order).head(1).reset_index(drop=True)
    
    # Combine test and train results
    merge_cols = benchmarked_model_key+param_cols+['path_to_model']
    combined_results = pd.merge(
        test_metrics,train_losses,
        on=merge_cols,
        how='left'
        )
    # Combine with benchmark results
    bench_metrics = bench_metrics.rename(columns = {'AUROC': 'Fusion AUROC',
                                                 'F1 Score': 'Fusion F1 Score',
                                                 'Accuracy': 'Fusion Accuracy',
                                                 'Precision': 'Fusion Precision',
                                                 'Recall': 'Fusion Recall'})
    combined_results = pd.merge(
        combined_results,bench_metrics,
        on=merge_cols,
        how='left'
        )
    
    # reorder columns
    combined_results = combined_results[[
        'Model Type','Model Name','Model Epoch',
        'Accuracy','Precision','Recall','F1 Score','AUROC',
        'Fusion Accuracy','Fusion Precision','Fusion Recall','Fusion F1 Score','Fusion AUROC',
        'caid_model_learning_rate','caid_model_num_epochs','caid_model_num_layers','caid_model_num_heads','caid_model_dropout','caid_model_epoch','caid_model_loss','path_to_model'
    ]]
    combined_results.to_csv(f"{output_dir}/best_caid_model_results.csv",index=False)

def get_fresh_model(training_hyperparams, device):
    input_dim, hidden_dim = 1280, 1280
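    # 1280 matches the per-residue embedding dimension of the ESM-2-650M-style encoders being benchmarked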
    model = DisorderPredictor(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=training_hyperparams["num_layers"],
        num_heads=training_hyperparams["num_heads"],
        dropout=training_hyperparams['dropout']
    )
    model.to(device) # Push model to device (should be GPU)
    
    return model

def predict_from_best_thresh(prob_and_label_df, seq_label_dict=None):
    """
    Finds the prediction threshold that maximizes the F1 score on the given probabilities, then uses it to make per-residue disorder predictions.
    Args:
        prob_and_label_df: DataFrame with columns: sequence,prob_1
        seq_label_dict: dictionary of sequences to true labels. e.g. 'MKLP': '1100'
    Returns:
        prob_and_label_df: new version of original dataframe with added columns: threshold,pred_labels
    """
    # Use seq_label_dict to insert labels
    prob_and_label_df['labels'] = prob_and_label_df['sequence'].map(seq_label_dict)
    # EVERYTHING should have a label!!
    assert prob_and_label_df['labels'].notna().all()
    
    probs = ','.join(prob_and_label_df['prob_1'].tolist())
    probs = [float(x) for x in probs.split(",")]
    true_labels = ''.join(prob_and_label_df['labels'].tolist())
    true_labels = [int(x) for x in list(true_labels)]
    total_aas = sum(prob_and_label_df['sequence'].str.len())
    log_update(f"\tLength of dataframe (number of seqs in dataset): {len(prob_and_label_df)}")
    log_update(f"\tTotal AAs in dataset: {total_aas}\ttotal probabilities: {len(probs)}\ttotal labels: {len(true_labels)}")

    y_true = np.array(true_labels)  # True labels
    y_probs = np.array(probs)  # Predicted probabilities

    # Compute precision, recall, and thresholds
    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)
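    # precision_recall_curve returns one more precision/recall value than thresholds;
    # trim the final point so the arrays align with `thresholds` below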
    precision = precision[:-1]
    recall = recall[:-1]
    # Calculate F1 scores for each threshold, guarding against division by zero when precision + recall == 0
    f1_scores = np.divide(
        2 * precision * recall, precision + recall,
        out=np.zeros_like(precision), where=(precision + recall) > 0
    )

    # Find the threshold that maximizes the F1 score
    best_threshold_index = np.argmax(f1_scores)
    best_threshold = thresholds[best_threshold_index]

    # Compute AUPRC
    auprc = average_precision_score(y_true, y_probs)

    log_update(f"\tBest Threshold: {best_threshold}")
    log_update(f"\tBest F1 Score: {f1_scores[best_threshold_index]:.2f}")
    log_update(f"\tAUPRC: {auprc:.2f}")

    # Edit the original DataFrame
    # Add threshold
    prob_and_label_df['threshold'] = [best_threshold]*len(prob_and_label_df)
    # Make predictions using this new threshold
    prob_and_label_df['pred_labels'] = prob_and_label_df['prob_1'].apply(lambda x: ['1' if float(y)>best_threshold else '0' for y in x.split(",")])
    prob_and_label_df['pred_labels'] = prob_and_label_df['pred_labels'].apply(lambda x: ''.join(x))
    log_update("\tUsed calculated threshold to construct predicted labels for dataset")
    return prob_and_label_df
    
    
def train_and_evaluate_caid_predictor(embedding_path, details, output_dir, training_hyperparams, overwrite_saved_model=True):
    # unpack the details dictionary
    benchmark_model_type = details['model_type']
    benchmark_model_name = details['model']
    benchmark_model_epoch = details['epoch']
    
    # define model save directories and make if they don't exist
    model_outer_folder = f"trained_models/{benchmark_model_type}"
    if not(np.isnan(benchmark_model_epoch)): model_outer_folder+=f"/{benchmark_model_name}/epoch{benchmark_model_epoch}"
    model_full_folder=f"{model_outer_folder}/lr{training_hyperparams['learning_rate']}_bs{1}_hd{1280}_epochs{training_hyperparams['num_epochs']}_layers{training_hyperparams['num_layers']}_heads{training_hyperparams['num_heads']}_drpt{training_hyperparams['dropout']}"
    os.makedirs(model_full_folder, exist_ok=True)   # creates any missing intermediate directories as well
        
    # see if we've trained the model before 
    model_full_path = f"{model_full_folder}/model.pth"
    train_new_model=True       #initially, we believe we're training a new model. Let's make sure we want to. 
    if os.path.exists(model_full_path):
        # If the model exists and we ARE allowed to overwrite, still train
        if overwrite_saved_model:
            log_update(f"\nOverwriting previously trained model with same hyperparams at {model_full_path}")
        # If the model exists and we are NOT allowed to overwrite, don't train
        else:
            log_update(f"\nWARNING: this model may already be trained at {model_full_path}. Skipping")
            train_new_model=False
    
    # If training new model, get new model stats. 
    if train_new_model:
        max_length=4500+2   # maximum padded sequence length (+2 presumably for start/end special tokens)
        # make Dataloaders
        train_dataloader = get_dataloader('splits/train_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=True)
        test_dataloader = get_dataloader('splits/test_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False)
        benchmark_dataloader = get_dataloader('splits/fusion_bench_df.csv', embedding_path, max_length=max_length, batch_size=1, shuffle=False)
            
        # Set device to GPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the model and move it to the device
        model = get_fresh_model(training_hyperparams, device)
        
        # Initialize optimizer
        optimizer = optim.Adam(model.parameters(), lr=training_hyperparams["learning_rate"])
        criterion = nn.BCELoss()
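        # BCELoss expects probabilities in [0, 1]; the predictor's outputs are treated as per-residue disorder probabilities throughout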
        num_epochs = training_hyperparams['num_epochs']

        ################# Train
        # Train loop
        avg_train_losses = train(model, train_dataloader, optimizer, num_epochs, criterion, device)
        # Save the training loss curve results
        formatted_hyperparams = {f"caid_model_{k}":v for k, v in training_hyperparams.items()}
        train_loss_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T
        train_loss_df['caid_model_epoch'] = [list(range(1,1+num_epochs))]
        train_loss_df['caid_model_loss'] = [avg_train_losses]
        train_loss_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]]
        train_loss_df = train_loss_df.explode(['caid_model_epoch', 'caid_model_loss'])
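        # explode produces one row per training epoch (epoch number paired with its average loss)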
        
        # Save loss results - both to the model folder (including hyperparams), AND to the current results folder
        train_loss_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_train_losses.csv'
        train_loss_individual_results_csv_path = f'{model_full_folder}/caid_train_losses.csv'
        train_loss_df.to_csv(train_loss_individual_results_csv_path,mode='w',index=False)
        train_loss_df['path_to_model'] = model_full_path
        if not(os.path.exists(train_loss_combined_results_csv_path)):
            train_loss_df.to_csv(train_loss_combined_results_csv_path,index=False)
        else:
            train_loss_df.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False)
        
        log_update(f"Final train loss: {avg_train_losses[-1]:.4f}")

        ################# Test
        # Evaluate model on test sequences
        test_sequences, test_preds, test_labels = evaluate(model, test_dataloader, device)
        test_metrics = calculate_metrics(test_preds, test_labels)
        # Make dataframe of test metric results
        test_results_df = pd.DataFrame.from_dict(test_metrics,orient='index').T
        test_results_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]]
        # add the hyperparameters to this
        hyperparams_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T
        test_results_df = pd.concat([test_results_df,hyperparams_df],axis=1)
        
        # Make dataframe of test probabilities (for AUROC curve)
        # Create a pandas DataFrame
        prob_and_label_df = pd.DataFrame(data = {
            'sequence': test_sequences,
            'prob_1': [arr.flatten() for arr in test_preds]
        })
        prob_and_label_df['prob_1'] = prob_and_label_df['prob_1'].apply(
            lambda prob_list: ",".join([f"{round(x, 3):.3f}" for x in prob_list])
        )
        prob_and_label_df['Model Type'] = [benchmark_model_type]*len(prob_and_label_df)
        prob_and_label_df['Model Name'] = [benchmark_model_name]*len(prob_and_label_df)
        prob_and_label_df['Model Epoch'] = [benchmark_model_epoch]*len(prob_and_label_df) 
        
        # Save test results - both to the model folder (including hyperparams), AND to the current results folder
        test_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_test_metrics.csv'
        test_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_metrics.csv'
        test_results_df.to_csv(test_results_csv_path,mode='w',index=False)
        test_results_df['path_to_model'] = model_full_path
        if not(os.path.exists(test_combined_results_csv_path)):
            test_results_df.to_csv(test_combined_results_csv_path,index=False)
        else:
            test_results_df.to_csv(test_combined_results_csv_path,mode='a',index=False,header=False)
        
        # Save test probs - only to model folder
        test_probs_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_probs.csv'
        seq_label_dict = pd.read_csv('splits/test_df.csv')
        seq_label_dict = dict(zip(seq_label_dict['Sequence'],seq_label_dict['Label']))
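        # seq_label_dict maps each test sequence to its ground-truth per-residue label string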
        log_update("Finding best threshold for CAID test set predictions based on maximizing F1 Score...")
        prob_and_label_df = predict_from_best_thresh(prob_and_label_df, seq_label_dict=seq_label_dict)
        prob_and_label_df[['sequence','prob_1','threshold','pred_labels']].to_csv(test_probs_csv_path,mode='w',index=False)
            
        log_update(f"Test performance: {test_metrics}")
        
        ################# Benchmark
        # Evaluate model on benchmark sequences
        benchmark_sequences, benchmark_preds, benchmark_labels = benchmark(model, benchmark_dataloader, device)
        benchmark_metrics = calculate_metrics(benchmark_preds, benchmark_labels)
        # Make dataframe of benchmark metric results
        benchmark_results_df = pd.DataFrame.from_dict(benchmark_metrics,orient='index').T
        benchmark_results_df[['Model Type','Model Name','Model Epoch']] = [[benchmark_model_type,benchmark_model_name,benchmark_model_epoch]]
        # add the hyperparameters to this
        hyperparams_df = pd.DataFrame.from_dict(formatted_hyperparams,orient='index').T
        benchmark_results_df = pd.concat([benchmark_results_df,hyperparams_df],axis=1)
        
        # Make dataframe of benchmark probabilities (for AUROC curve)
        # Create a pandas DataFrame
        prob_and_label_df = pd.DataFrame(data = {
            'sequence': benchmark_sequences,
            'prob_1': [arr.flatten() for arr in benchmark_preds]
        })
        prob_and_label_df['prob_1'] = prob_and_label_df['prob_1'].apply(
            lambda prob_list: ",".join([f"{round(x, 3):.3f}" for x in prob_list])
        )
        prob_and_label_df['Model Type'] = [benchmark_model_type]*len(prob_and_label_df)
        prob_and_label_df['Model Name'] = [benchmark_model_name]*len(prob_and_label_df)
        prob_and_label_df['Model Epoch'] = [benchmark_model_epoch]*len(prob_and_label_df) 
        
        # Save benchmark results - both to the model folder (including hyperparams), AND to the current results folder
        benchmark_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv'
        benchmark_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_metrics.csv'
        benchmark_results_df.to_csv(benchmark_results_csv_path,mode='w',index=False)
        benchmark_results_df['path_to_model'] = model_full_path
        if not(os.path.exists(benchmark_combined_results_csv_path)):
            benchmark_results_df.to_csv(benchmark_combined_results_csv_path,index=False)
        else:
            benchmark_results_df.to_csv(benchmark_combined_results_csv_path,mode='a',index=False,header=False)
        
        # Save benchmark probs - only to model folder
        benchmark_probs_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_probs.csv'
        seq_label_dict = pd.read_csv('splits/fusion_bench_df.csv')
        seq_label_dict = dict(zip(seq_label_dict['Sequence'],seq_label_dict['Label']))
        log_update("Finding best threshold for fusion benchmark set predictions based on maximizing F1 Score...")
        prob_and_label_df = predict_from_best_thresh(prob_and_label_df, seq_label_dict=seq_label_dict)
        prob_and_label_df[['sequence','prob_1','threshold','pred_labels']].to_csv(benchmark_probs_csv_path,mode='w',index=False)
            
        log_update(f"benchmark performance: {benchmark_metrics}")

        ################# Save model
        # Save model and metrics for this hyperparameter combination in the trained models folder 
        torch.save(model.state_dict(), model_full_path)

    # if we didn't train again, still add those results to this benchmarking run so that they all get compared together
    else:
        # Load the appropriate train losses
        train_loss_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_train_losses.csv'
        train_loss_individual_results_csv_path = f'{model_full_folder}/caid_train_losses.csv'
        train_loss_individual_results = pd.read_csv(train_loss_individual_results_csv_path)
        train_loss_individual_results['path_to_model'] = [model_full_path]*len(train_loss_individual_results)
        # Add these results to the combined results file for this run if it exists; otherwise create new combined results file
        if not(os.path.exists(train_loss_combined_results_csv_path)):
            train_loss_individual_results.to_csv(train_loss_combined_results_csv_path,index=False)
        else:
            train_loss_individual_results.to_csv(train_loss_combined_results_csv_path,mode='a',index=False,header=False)
            
        # Load the appropriate test stats
        test_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_test_metrics.csv'
        test_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_test_metrics.csv'
        test_individual_results = pd.read_csv(test_results_csv_path)
        test_individual_results['path_to_model'] = [model_full_path]*len(test_individual_results)
        # Add these results to the combined results file for this run if it exists; otherwise create new combined results file
        if not(os.path.exists(test_combined_results_csv_path)):
            test_individual_results.to_csv(test_combined_results_csv_path,index=False)
        else:
            test_individual_results.to_csv(test_combined_results_csv_path,mode='a',index=False,header=False)
            
        # Load the appropriate benchmark stats
        benchmark_combined_results_csv_path = f'{output_dir}/caid_hyperparam_screen_fusion_benchmark_metrics.csv'
        benchmark_results_csv_path = f'{model_full_folder}/caid_hyperparam_screen_fusion_benchmark_metrics.csv'
        benchmark_individual_results = pd.read_csv(benchmark_results_csv_path)
        benchmark_individual_results['path_to_model'] = [model_full_path]*len(benchmark_individual_results)
        # Add these results to the combined results file for this run if it exists; otherwise create new combined results file
        if not(os.path.exists(benchmark_combined_results_csv_path)):
            benchmark_individual_results.to_csv(benchmark_combined_results_csv_path,index=False)
        else:
            benchmark_individual_results.to_csv(benchmark_combined_results_csv_path,mode='a',index=False,header=False)
        
# Metrics calculation
def calculate_metrics(preds, labels, threshold=0.5):
    """
    Calculates metrics to assess model performance
    Args:
        preds (list): model's predictions (probabilities)
        labels (list): ground truth labels
        threshold (float): minimum probability a prediction must meet to be labeled disordered
    Returns:
        metrics_dict (dict): dictionary with 'Accuracy', 'Precision', 'Recall', 'F1 Score', and 'AUROC'
    """
    flat_binary_preds, flat_prob_preds, flat_labels = [], [], []

    for pred, label in zip(preds, labels):
        flat_binary_preds.extend((pred > threshold).astype(int).flatten())   # binary preds are 1 or 0; 1 if the prob > threshold
        flat_prob_preds.extend(pred.flatten())
        flat_labels.extend(label.flatten())

    flat_binary_preds = np.array(flat_binary_preds)
    flat_prob_preds = np.array(flat_prob_preds)
    flat_labels = np.array(flat_labels)

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)
    roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    
    # make dictionary of the results and return it
    metrics_dict = {
        'Accuracy': accuracy, 
        'Precision': precision, 
        'Recall': recall, 
        'F1 Score': f1, 
        'AUROC': roc_auc
    }

    return metrics_dict
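
# Example with illustrative values only (hypothetical, not from the CAID data):
#   preds  = [np.array([0.9, 0.2, 0.7]), np.array([0.1, 0.8, 0.4])]
#   labels = [np.array([1, 0, 1]),       np.array([0, 1, 0])]
# calculate_metrics(preds, labels) thresholds at 0.5, giving binary predictions [1,0,1,0,1,0]
# that match the flattened labels exactly, so every metric (including AUROC) evaluates to 1.0.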

def main():
    # make output directory for this run
    os.makedirs('results',exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir,exist_ok=True)
    
    with open_logfile(f'{output_dir}/caid_benchmark_log.txt'):
        # print configurations 
        print_configpy(config)
        
        # Verify that the environment variables are set correctly 
        check_env_variables()
        
        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
                                            fuson_ckpts=config.FUSONPLM_CKPTS, 
                                            input_data_path='splits/splits.csv', 
                                            input_fname='CAID2_competition_sequences', 
                                            average=False, seq_col='Sequence',
                                            benchmark_fusonplm=config.BENCHMARK_FUSONPLM, 
                                            benchmark_esm=config.BENCHMARK_ESM, 
                                            benchmark_fo_puncta_ml=False, 
                                            overwrite=config.PERMISSION_TO_OVERWRITE_EMBEDDINGS)
        
        # load the splits with labels
        splits_df = pd.read_csv('splits/splits.csv')
        log_update(f"\nSplit breakdown...\n\t{len(splits_df.loc[splits_df['Split']=='Train'])} train seqs\n\t{len(splits_df.loc[splits_df['Split']=='Test'])} test seqs")
        
        log_update("\nTraining and evaluating models")
        # Set hyperparameters for disorder predictor 
        
        param_grid = {
            'learning_rate': [5e-5],
            'num_heads': [5, 8, 10],
            'num_layers': [2, 4, 6],
            'dropout': [0.2, 0.5],
            'num_epochs': [2]
        }
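        # 1 learning rate x 3 head counts x 3 layer counts x 2 dropout rates x 1 epoch setting = 18 combinations per embedding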
        
        # loop through the embedding paths and train each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\nBenchmarking embeddings at: {embedding_path}")
            
            grid_search_caid_predictor(embedding_path, details, output_dir, param_grid, overwrite_saved_model=config.PERMISSION_TO_OVERWRITE_MODELS)
            
        # find the best grid search performer
        find_best_hyperparams(output_dir, param_grid)
        
        # make plots
        #### caid test set
        best_caid_model_results = pd.read_csv(f"{output_dir}/best_caid_model_results.csv")
        #### fusion benchmark set
        best_caid_model_results_benchmark = best_caid_model_results.drop(columns=
            ['AUROC','F1 Score','Accuracy','Precision','Recall']
            ).rename(columns={
                'Fusion AUROC': 'AUROC',
                'Fusion F1 Score': 'F1 Score',
                'Fusion Accuracy': 'Accuracy',
                'Fusion Precision': 'Precision',
                'Fusion Recall': 'Recall'
                })
        
if __name__ == "__main__":
    main()