File size: 46,902 Bytes
bae913a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from matplotlib import font_manager
import matplotlib.patches as patches
from sklearn.metrics import roc_curve, auc, r2_score

from fuson_plm.utils.visualizing import set_font

# Reference CAID-2 results that the plots in this module compare against.
# Each entry is (model name, reported AUROC, ranking number). The ranking number is what
# make_auroc_curve prefixes to legend labels when with_rankings=True.
_CAID2_REFERENCE = [
    ('Dispredict3', 0.838, 1),
    ('flDPnn2', 0.836, 2),
    ('flDPnn', 0.833, 3),
    ('flDPlr', 0.827, 4),
    ('flDPlr2', 0.821, 5),
    ('DisoPred', 0.821, 6),
    ('IDP-Fusion', 0.818, 7),
    ('ESpritz-D', 0.802, 8),
    ('DeepIDP-2L', 0.800, 9),
    ('disomine', 0.797, 10),
    ('DISOPRED3-diso', 0.692, 35),
    ('IUPred3', 0.755, 21),
    ('AlphaFold-rsa', 0.747, 24),
    ('AlphaFold-pLDDT', 0.695, 34),
]

# One row per reference model, shaped like the results dataframes passed to make_heatmap
# (same 'Model Name' / 'AUROC' / 'Model Type' / 'Model Epoch' columns) so the two can be concatenated.
caid2_winners = pd.DataFrame({
    'Model Name': [name for name, _, _ in _CAID2_REFERENCE],
    'AUROC': [auroc for _, auroc, _ in _CAID2_REFERENCE],
})
caid2_winners['Model Type'] = 'caid2_competition'
caid2_winners['Model Epoch'] = np.nan   # reference models have no training epoch

# Model name -> ranking number
caid2_model_rankings = {name: rank for name, _, rank in _CAID2_REFERENCE}

# Method for lengthening the model name
def lengthen_model_name(row):
    """
    Build the full model name for one results row.

    Args:
        row: mapping / pandas row with 'Model Type', 'Model Name' and 'Model Epoch' entries.

    Returns:
        The 'Model Name' unchanged when it contains 'esm' or 'puncta', or when the row's
        'Model Type' is 'caid2_competition'; otherwise '<Model Name>_e<Model Epoch>'.
    """
    model_type, name, epoch = row['Model Type'], row['Model Name'], row['Model Epoch']

    # These models are identified by name alone, so no epoch suffix is attached
    name_is_complete = ('esm' in name) or ('puncta' in name) or (model_type == 'caid2_competition')
    return name if name_is_complete else f'{name}_e{epoch}'

# Method for shortening the model name for display
def shorten_model_name(row):
    """
    Build a compact display name for one results row.

    Args:
        row: mapping / pandas row with 'Model Type', 'Model Name' and 'Model Epoch' entries.

    Returns:
        - 'ESM-2-650M' when the model name contains 'esm'.
        - The model name unchanged when 'Model Type' is 'caid2_competition'.
        - '<snp|uni>_<layers>L_<kqv tag>_mask<rate>_<date>_e<epoch>' when the name contains
          'snp_' or 'uniform_' and has the '..._<N>layers_<tag>_...mask<rate>-<date>' layout
          that the parsing below relies on.
        - Otherwise (the name cannot be parsed) the model name itself, with an '_e<epoch>'
          suffix when an epoch is present. Previously this case raised UnboundLocalError or
          IndexError.
    """
    model_type = row['Model Type']
    name = row['Model Name']
    epoch = row['Model Epoch']

    if 'esm' in name:
        return 'ESM-2-650M'
    if model_type == 'caid2_competition':
        return name

    prob_type = None
    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'

    if prob_type is not None:
        try:
            # Token right before the first 'layers', e.g. '3' in '..._3layers_...'
            layers = name.split('layers')[0].split('_')[-1]
            # Token right after 'layers_', e.g. 'kqv'
            kqv_tag = name.split('layers_')[1].split('_')[0]
            # Text after 'mask' is '<rate>-<date>'; split only on the first dash
            mask_fields = name.split('mask')[1].split('-', 1)
            maskrate, dt = mask_fields[0], mask_fields[1]
        except IndexError:
            # Name does not follow the expected layout; use the fallback below
            pass
        else:
            return f'{prob_type}_{layers}L_{kqv_tag}_mask{maskrate}_{dt}_e{epoch}'

    # Fallback for names that cannot be shortened: keep them readable and unique per epoch
    has_epoch = (not pd.isna(epoch)) and epoch != ''
    return f'{name}_e{epoch}' if has_epoch else name

def make_heatmap(df, results_dir='.', gold_standard_model_name="esm2_t33_650M_UR50D",split="test",thresh=None,ax=None):
    """
    Draw a one-column AUROC heatmap comparing every model in ``df`` to a gold-standard model, and save it as a PNG.

    Cell colors encode (model AUROC - gold standard AUROC); the printed numbers are the raw AUROCs,
    bold where the model beats the gold standard. The gold standard's tick label gets a gold box and
    models that beat it get a red box.

    Args:
        df: results dataframe with 'Model Type', 'Model Name', 'Model Epoch' and 'AUROC' columns.
            It is not modified; the function works on a copy.
        results_dir: directory the PNG is written to.
        gold_standard_model_name: full model name (as produced by lengthen_model_name) of the reference row.
        split: "testing", "training" or "benchmark" ("test" / "train" are accepted as aliases).
            For every split except "benchmark" the module-level caid2_winners rows are appended.
        thresh: unused; kept for interface compatibility.
        ax: optional matplotlib Axes to draw on. When None a new figure is created.

    Raises:
        KeyError: if ``split`` is not one of the supported values.
        IndexError / ValueError: if the gold standard is missing from the data or matched more than once.
    """
    # Resolve the split first so an unsupported value fails before any plotting work.
    # The signature default is "test", which the lookup tables below did not previously contain.
    split_key = {"test": "testing", "train": "training"}.get(split, split)
    split_fname_dict = {
        "testing": "CAID2_test",
        "training": "CAID2_train",
        "benchmark": "FusionPDB_pLDDT_disorder"
    }
    split_title_dict = {
        "testing": "CAID-2 Disorder Prediction",
        "training": "CAID-2 Disorder Prediction",
        "benchmark": "FusionPDB_pLDDT Disorder Prediction"
    }
    fname_prefix = split_fname_dict[split_key]
    plot_title = split_title_dict[split_key]

    # Set font to Ubuntu
    set_font()

    # Declare columns to compare: metrics
    columns_to_compare = ['AUROC']

    # Add the literature-reported values for CAID competition models - only IF the split is not "benchmark"
    if split_key != "benchmark":
        df = pd.concat([df, caid2_winners])
    else:
        df = df.copy()  # the columns added below must not leak into the caller's dataframe

    # Create Short Model Name and Full Model Name columns for later use
    df['Model Epoch'] = df['Model Epoch'].apply(lambda x: '' if pd.isna(x) else str(int(x)))
    df['Short Model Name'] = df.apply(shorten_model_name, axis=1)
    df['Full Model Name'] = df.apply(lengthen_model_name, axis=1)

    # Isolate gold standard row for later comparison
    gold_rows = df[df['Full Model Name'] == gold_standard_model_name]
    gold_standard = gold_rows.reset_index(drop=True).iloc[0]
    gold_standard_short_model_name = gold_rows['Short Model Name'].item()

    # Create a new dataframe for the heatmap; sort by model type and place gold standard on top
    heatmap_data = df[['Model Type','Short Model Name','Full Model Name'] + columns_to_compare].copy()
    heatmap_data['is_gold_standard'] = (heatmap_data['Full Model Name'] == gold_standard_model_name).astype(int)
    heatmap_data = (
        heatmap_data
        .sort_values(by=['is_gold_standard','Model Type','AUROC'], ascending=[False,True,False])
        .reset_index(drop=True)
        .drop(columns=['is_gold_standard'])
    )
    # Save the original values before calculating differences so we can use them for annotation
    original_values = heatmap_data[columns_to_compare].copy()

    # Calculate differences from the gold standard
    for col in columns_to_compare:
        heatmap_data[col] = heatmap_data[col] - gold_standard[col]

    # Diverging color map, centered on 0 in the heatmap call below
    cmap = sns.color_palette("coolwarm", as_cmap=True)  # other option is diverging_palette(220, 20, as_cmap=True)

    ### Make the plot
    # can plot on a bigger plot, or make it an individual plot
    if ax is None:
        tallsize = max(8, 8 +.25*(len(heatmap_data)-26))
        fig, ax = plt.subplots(1, 1, figsize=(8, tallsize), dpi=300)
    else:
        fig = ax.figure  # needed below for tight_layout / savefig

    # Plot the heatmap on the requested axes; numbers are added manually below
    hm = sns.heatmap(heatmap_data.set_index('Short Model Name').drop(columns=['Model Type','Full Model Name']),
                    annot=False, fmt='', cmap=cmap, center=0,
                    cbar_kws={'label': 'Difference from Gold Standard'},
                    ax=ax)

    # Explicitly set tick labels to prevent them from being messed up
    ax.set_yticklabels(heatmap_data['Short Model Name'], rotation=0, fontsize=12)
    # Add padding to the y-axis label
    ax.set_ylabel("Short Model Name", labelpad=20)  # Increase the labelpad value to add more padding

    # Annotate each cell with the raw value; bold any value that exceeds the gold standard
    for i in range(original_values.shape[0]):
        for j, col in enumerate(columns_to_compare):
            value = original_values.iloc[i, j]
            weight_kwargs = {'fontweight': 'bold'} if value > gold_standard[col] else {}
            ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', color='black', **weight_kwargs)

    # Add horizontal lines between different model types
    model_type_series = heatmap_data['Model Type'].values
    for i in range(1, len(model_type_series)):
        if model_type_series[i] != model_type_series[i - 1]:
            hm.axhline(i, color='white', linewidth=8)  # Draw a thick white line between groups

    # Box the gold standard's label in gold and the label of every model that beats its AUROC in red
    tick_labels = ax.get_yticklabels()
    for ytick, model_name in enumerate(heatmap_data['Short Model Name']):
        if model_name == gold_standard_short_model_name:
            tick_labels[ytick].set_bbox(dict(facecolor='gold', alpha=0.5, edgecolor='gold'))
        elif original_values.loc[ytick, 'AUROC'] > gold_standard['AUROC']:
            tick_labels[ytick].set_bbox(dict(facecolor='red', alpha=0.3, edgecolor='red'))

    # Make legend
    gold_patch = mpatches.Patch(color='gold', alpha=0.5, label='Gold Standard')
    red_patch = mpatches.Patch(color='red', alpha=0.5, label='Winner')
    ax.legend(handles=[gold_patch, red_patch], loc='best', bbox_to_anchor=(0, 0))  # You can change loc to position the legend

    ax.set_title(plot_title)

    # Rotate the color bar label
    cbar = hm.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')
    cbar.ax.yaxis.set_ticks_position('right')
    cbar.set_label('Difference from Gold Standard', rotation=270, labelpad=20)  # Rotate 270 degrees and add some padding

    # Set tight layout using fig
    fig.tight_layout(rect=[0, 0, 0.95, 1])  # Add extra padding on the right side to fit the label

    fig.savefig(f"{results_dir}/{fname_prefix}_heatmap_vs_{gold_standard_model_name}.png")

# Plot AUROC curve of ONE model of interest on its fusion pdb performance
def make_benchmark_auroc_curve(results_dir='.', seq_label_dict=None, path_to_results_of_interest='', model_alias=None):
    """
    Plot the ROC curve of a single model's predictions and save it as
    '<results_dir>/FusionPDB_pLDDT_disorder_<model_alias>_AUROC_curve.png'.

    Args:
        results_dir: directory the PNG is written to.
        seq_label_dict: mapping from sequence to its label string (one '0'/'1' digit per residue).
        path_to_results_of_interest: CSV of predictions; must contain 'trained_models/' in the path and
            have 'sequence' and 'prob_1' columns, where 'prob_1' is a comma-separated string of floats.
        model_alias: legend / filename label. When None it is derived from the path via shorten_model_name.
    """
    # Path components after 'trained_models/' describe the model
    path_parts = path_to_results_of_interest.split('trained_models/')[1].split('/')
    model_type = path_parts[0]
    if len(path_parts) == 5:
        model_name = path_parts[1]
        model_epoch = path_parts[2].split('epoch')[1]
        hyperparam_tag = path_parts[3]   # parsed but not used below
    else:
        model_name = path_parts[0]
        model_epoch = np.nan
        hyperparam_tag = path_parts[1]   # parsed but not used below

    if model_alias is None:
        model_row = pd.DataFrame(data={
            'Model Type': [model_type], 'Model Name': [model_name], 'Model Epoch': [model_epoch]
        })
        model_alias = model_row.apply(shorten_model_name, axis=1).iloc[0]

    color_map = {model_alias: 'black'}
    # Only keep entries that actually point at a results file
    method_results = {
        alias: csv_path
        for alias, csv_path in {model_alias: path_to_results_of_interest}.items()
        if csv_path not in [None, '']
    }

    set_font()
    plt.figure(figsize=(10,6),dpi=300)

    # Collect (method, fpr, tpr, auroc) per results file
    roc_data = []
    for method, path in method_results.items():
        df = pd.read_csv(path) # uses the 'sequence' and 'prob_1' columns

        # Flatten per-sequence strings into one residue-level list of probabilities / labels
        prob_1 = [float(tok) for tok in ",".join(df['prob_1'].tolist()).split(",")]
        df['labels'] = df['sequence'].apply(lambda seq: seq_label_dict[seq])
        labels = [int(ch) for ch in "".join(df['labels'].tolist())]
        sequences = "".join(df['sequence'].tolist())
        assert len(prob_1)==len(labels)==len(sequences)

        # Compute ROC curve and ROC area
        fpr, tpr, thresholds = roc_curve(labels, prob_1)
        roc_data.append((method, fpr, tpr, auc(fpr, tpr)))

    # Highest AUROC first, so the legend is ordered by performance
    roc_data.sort(key=lambda entry: entry[3], reverse=True)

    # Plot sorted ROC curves; the model of interest is drawn thicker and fully opaque
    for method, fpr, tpr, roc_auc in roc_data:
        line_style = {'lw': 2} if method == model_alias else {'lw': 1, 'alpha': 0.7}
        plt.plot(fpr, tpr, color=color_map[method], label=f'{method} ({roc_auc:0.3f})', **line_style)

    # Set other stylistic elements
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')

    # Place the legend outside the axes, to the right
    handles, labels = plt.gca().get_legend_handles_labels()
    legend = plt.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))

    # Bold the legend entry of the model of interest
    for text in legend.get_texts():
        if model_alias in text.get_text():
            text.set_fontweight('bold')

    plt.tight_layout()
    plt.savefig(f'{results_dir}/FusionPDB_pLDDT_disorder_{model_alias}_AUROC_curve.png')
    
# Plot AUROC curve of ONE model of interest with all the CAID models
def make_auroc_curve(results_dir='.', seq_label_dict=None, seq_ids_dict=None, path_to_results_of_interest='', model_alias=None, path_to_esm_results=None, with_rankings=False):
    """
    Plot the ROC curve of one model of interest together with the CAID-2 reference methods, and save
    the figure plus two source-data CSVs into ``results_dir``.

    Args:
        results_dir: output directory. Receives 'CAID_prediction_source_data.csv',
            'CAID_fpr_tpr_source_data.csv' and the PNG.
        seq_label_dict: mapping from sequence to its label string (one digit per residue).
        seq_ids_dict: mapping from sequence to the value written to the 'ids' source-data column.
        path_to_results_of_interest: prediction CSV of the model of interest. The path must contain
            'trained_models/', and the same directory must hold 'caid_hyperparam_screen_test_metrics.csv',
            whose first 'AUROC' value is the one shown in the legend for this model.
        model_alias: legend / filename label for the model of interest. When None it is derived from
            the path via shorten_model_name.
        path_to_esm_results: optional prediction CSV plotted as an extra dashed 'ESM-2-650M' curve.
        with_rankings: when True, legend labels of methods in caid2_model_rankings are prefixed with
            their ranking number.

    Every prediction CSV must have a 'sequence' column and a 'prob_1' column holding a comma-separated
    string of per-residue floats. The CAID-2 reference CSVs are read from hard-coded relative paths
    under 'processed_data/caid2_competition_results/'.
    """
    # Isolate the information for the model we'll be plotting
    benchmark_model = path_to_results_of_interest.split('trained_models/')[1].split('/')
    benchmark_model_type = benchmark_model[0]
    benchmark_model_epoch = np.nan
    benchmark_model_hyperparams = None
    if len(benchmark_model)==5: 
        benchmark_model_name = benchmark_model[1]
        benchmark_model_epoch = benchmark_model[2].split('epoch')[1]
        benchmark_model_hyperparams = benchmark_model[3]
    else:
        benchmark_model_name = benchmark_model[0]
        benchmark_model_hyperparams = benchmark_model[1]
    benchmark_model_info = pd.DataFrame(data={
        'Model Type': [benchmark_model_type], 'Model Name': [benchmark_model_name], 'Model Epoch': [benchmark_model_epoch]
    })
    if model_alias is None:
        model_alias = benchmark_model_info.apply(lambda row: shorten_model_name(row),axis=1).iloc[0]
        
    # Line color per method; the trailing numbers mirror the entries in caid2_model_rankings
    color_map = {
        'Dispredict3': '#d62727',           #1
        'flDPnn2': '#ff7f0f',               #2
        'flDPnn': '#1f77b4',                #3
        'flDPlr': '#bcbd21',                #4
        'flDPlr2': '#16becf',               #5
        'DisoPred': '#1f77b4',              #6
        'IDP-Fusion': '#d62727',            #7
        'ESpritz-D': '#8b564c',             #8
        'DeepIDP-2L': '#e377c2',            #9
        'disomine': '#e377c2',                #10
        'DISOPRED3-diso': '#ff892d',             
        'IUPred3': '#8b564c',
        'AlphaFold-rsa': '#2ba02b',
        'AlphaFold-pLDDT': '#ff892d',
        model_alias: 'black'
    }
    # Prediction CSV per method (relative paths, resolved against the current working directory)
    method_results = {'Dispredict3': 'processed_data/caid2_competition_results/Dispredict3_CAID-2_Disorder_NOX.csv',
                    'flDPnn2': 'processed_data/caid2_competition_results/flDPnn2_CAID-2_Disorder_NOX.csv',
                    'flDPnn': 'processed_data/caid2_competition_results/flDPnn_CAID-2_Disorder_NOX.csv',
                    'flDPlr': 'processed_data/caid2_competition_results/flDPtr_CAID-2_Disorder_NOX.csv',   # name doesn't match but this is what it is in raw download
                    'flDPlr2': 'processed_data/caid2_competition_results/flDPlr2_CAID-2_Disorder_NOX.csv',
                    'DisoPred': 'processed_data/caid2_competition_results/DisoPred_CAID-2_Disorder_NOX.csv',
                    'IDP-Fusion': 'processed_data/caid2_competition_results/IDP-Fusion_CAID-2_Disorder_NOX.csv',        
                    'ESpritz-D': 'processed_data/caid2_competition_results/ESpritz-D_CAID-2_Disorder_NOX.csv',           
                    'DeepIDP-2L': 'processed_data/caid2_competition_results/DeepIDP-2L_CAID-2_Disorder_NOX.csv',          
                    'disomine': 'processed_data/caid2_competition_results/disomine_CAID-2_Disorder_NOX.csv',              
                    'DISOPRED3-diso': 'processed_data/caid2_competition_results/DISOPRED3-diso_CAID-2_Disorder_NOX.csv',             
                    'AlphaFold-rsa': 'processed_data/caid2_competition_results/AlphaFold-rsa_CAID-2_Disorder_NOX.csv',
                    'AlphaFold-pLDDT': 'processed_data/caid2_competition_results/AlphaFold-disorder_CAID-2_Disorder_NOX.csv',        # name doesn't match but this is what it is in raw download
                    'IUPred3': 'processed_data/caid2_competition_results/IUPred3_CAID-2_Disorder_NOX.csv',
                    model_alias: path_to_results_of_interest
                }
    # Optionally add ESM-2 as a second highlighted curve
    if path_to_esm_results is not None:
        method_results['ESM-2-650M'] = path_to_esm_results
        color_map['ESM-2-650M'] = 'black'
        
    # Drop any method without a results path
    method_results = {k:v for k,v in method_results.items() if v not in [None, '']}
    
    set_font()
    plt.figure(figsize=(12,6),dpi=300)
    
    # Accumulators: per-sequence predictions of every method, per-method ROC points, and per-method AUROC
    merged_preds = pd.DataFrame(data={'sequence':[]})
    merged_tpr_fpr = pd.DataFrame(data={'model': [],'fpr':[],'tpr':[]})
    roc_data = []
    # Read each result file and compute its ROC curve
    for method, path in method_results.items():
        df = pd.read_csv(path) # uses the 'sequence' and 'prob_1' columns
        # Outer merge keeps sequences that only some methods have predictions for
        merged_preds = pd.merge(merged_preds, 
                                df.rename(columns={'prob_1':f"{method}_prob_1"})[['sequence',f"{method}_prob_1",]],
                                on=['sequence'],how='outer')
        
        # Flatten the per-sequence strings into one residue-level list of probabilities and labels
        prob_1 = ",".join(df['prob_1'].tolist())
        df['labels'] = df['sequence'].apply(lambda x: seq_label_dict[x])
        labels = "".join(df['labels'].tolist())
        prob_1 = [float(x) for x in prob_1.split(",")]
        labels = [int(x) for x in list(labels)]
        sequences = "".join(df['sequence'].tolist())
        # Sanity check: exactly one probability and one label per residue
        assert len(prob_1)==len(labels)==len(sequences)

        # Compute ROC curve and ROC area
        fpr, tpr, thresholds = roc_curve(labels, prob_1)
        new_tpr_fpr = pd.DataFrame(data={
            'model': [method]*len(fpr),
            'fpr': fpr, 'tpr': tpr
        })
        merged_tpr_fpr = pd.concat([merged_tpr_fpr,new_tpr_fpr])
        roc_auc = auc(fpr, tpr)
        
        # For the model of interest, report the AUROC stored in its metrics file (first row)
        # instead of the value recomputed above
        if method==model_alias:
            path_to_og_metrics = path_to_results_of_interest.rsplit('/',1)[0]+'/caid_hyperparam_screen_test_metrics.csv'
            og_metrics = pd.read_csv(path_to_og_metrics)
            roc_auc = og_metrics['AUROC'][0]
        
        # Store data for sorting later
        roc_data.append((method, fpr, tpr, roc_auc))
       
    # Save the merged dataframe as source data (labels as comma-separated digits; sequences replaced by ids)
    merged_preds['labels'] = merged_preds['sequence'].apply(lambda x: seq_label_dict[x])
    merged_preds['labels'] = merged_preds['labels'].apply(lambda x: ",".join([str(y) for y in x]))
    merged_preds['ids'] = merged_preds['sequence'].apply(lambda x: seq_ids_dict[x])
    merged_preds.drop(columns={'sequence'}).to_csv(f"{results_dir}/CAID_prediction_source_data.csv",index=False)
    merged_tpr_fpr.to_csv(f"{results_dir}/CAID_fpr_tpr_source_data.csv",index=False)
    # Sort the methods by AUROC values (highest first), which also fixes the legend order
    roc_data = sorted(roc_data, key=lambda x: x[3], reverse=True)
    
    # figure out the labels: plain method name, optionally prefixed with its ranking number
    labels = {method: method for method in method_results}
    if with_rankings:
        for method in labels:
            if method in caid2_model_rankings:
                labels[method] = f"{caid2_model_rankings[method]}. {method}"

    # Plot sorted ROC curves: ESM dashed, model of interest solid and thick, all others thin and translucent
    for method, fpr, tpr, roc_auc in roc_data:
        if method=='ESM-2-650M' and path_to_esm_results is not None:
            plt.plot(fpr, tpr, color=color_map[method], lw=2, linestyle='--', label=f'{labels[method]} ({roc_auc:0.3f})')
        elif method == model_alias:
            plt.plot(fpr, tpr, color=color_map[method], lw=2, label=f'{labels[method]} ({roc_auc:0.3f})')
        else:
            plt.plot(fpr, tpr, color=color_map[method], lw=1, alpha=0.7, label=f'{labels[method]} ({roc_auc:0.3f})')

    # Set other stylistic elements
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate', fontsize=22)
    plt.ylabel('True Positive Rate', fontsize=22)
    plt.title('CAID2 Disorder NOX Dataset: ROC Curve', fontsize=22)
    
    # After plotting the ROC curves, customize the legend
    # (note: `labels` is rebound here from the label dict to the list of legend strings)
    handles, labels = plt.gca().get_legend_handles_labels()

    # Create the legend first, placed outside the axes on the right
    legend = plt.legend(handles, labels, loc="center left", bbox_to_anchor=(1.1, 0.5), fontsize=16)

    # Iterate through the legend's text labels
    for text in legend.get_texts():
        if model_alias in text.get_text():
            text.set_fontweight('bold')  # Bold the alias model
        elif (path_to_esm_results is not None) and "ESM-2-650M" in text.get_text():
            text.set_fontweight('bold') # Bold ESM if we're comparing to it
        
    plt.tight_layout()
    # The filename records whether the ESM comparison curve is included
    figpath = f'{results_dir}/CAID2_{model_alias}_AUROC_curve.png'
    if path_to_esm_results is not None:
        figpath = f'{results_dir}/CAID2_{model_alias}_with_ESM_AUROC_curve.png'
    plt.savefig(figpath)

    
def plot_disorder_content_scatter(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_scatter.png'):
    """
    Scatter sequence length against fraction of disordered residues for the train, test and
    fusion benchmark sets (based on the TRUE labels) and save the figure to ``savepath``.

    Each labels argument is an iterable of per-residue label strings, e.g. ['11110000','0001110',...].
    """
    def length_and_fraction(label_vectors):
        # Per sequence: number of residues and the share of them labelled 1
        seq_lengths, disorder_fracs = [], []
        for label_string in label_vectors:
            residue_flags = list(map(int, label_string))
            seq_lengths.append(len(residue_flags))
            disorder_fracs.append(sum(residue_flags) / len(residue_flags))
        return seq_lengths, disorder_fracs

    train_lengths, train_frac_disorder = length_and_fraction(train_labels)
    test_lengths, test_frac_disorder = length_and_fraction(test_labels)
    benchmark_lengths, benchmark_frac_disorder = length_and_fraction(benchmark_labels)

    set_font()

    # (legend label, x values, y values, color) in drawing order
    series = (
        ('Train', train_lengths, train_frac_disorder, '#0072B2'),
        ('Test', test_lengths, test_frac_disorder, '#E69F00'),
        ('Fusion', benchmark_lengths, benchmark_frac_disorder, 'purple'),
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    for legend_label, xs, ys, point_color in series:
        ax.scatter(xs, ys, color=point_color, label=legend_label, alpha=0.7)

    # Labels and title
    ax.set_xlabel('Length')
    ax.set_ylabel('Fraction of Disorder')
    ax.set_title('Length vs. Fraction of Disorder for Train, Test, and Benchmark Datasets')
    ax.legend()
    plt.tight_layout()
    plt.savefig(savepath)

def plot_disorder_content_hist(labels, ids, title="data", color="black", savepath='splits/disorder_content_histograms.png'):
    """
    Plot a histogram of percent disorder per sequence for ONE dataset, based on the TRUE labels.

    Args:
        labels: iterable of per-residue label strings, e.g. ['11110000','0001110',...].
        ids: one identifier per entry of ``labels``; written to the 'ID' column of the source data.
        title: plot title.
        color: histogram bar color.
        savepath: where the figure is saved. The per-sequence values are also written next to it as
            '<savepath without extension>_source_data.csv'.
    """
    set_font()

    # Percent of residues labelled 1, per sequence
    frac_disorder = []
    for vec in labels:
        veclist = [int(x) for x in vec]
        frac_disorder.append(100*sum(veclist)/len(veclist)) # make it a percent, i like this better

    # save the source data
    source_data = pd.DataFrame(data={
        'ID': ids,
        'Percent_Disordered': frac_disorder
    })
    source_data['Percent_Disordered'] = source_data['Percent_Disordered'].round(3)
    # Derive the CSV path from the extension-less figure path, so that a non-.png savepath
    # does not make the CSV and the figure share (and overwrite) the same file
    source_data_path = os.path.splitext(savepath)[0] + "_source_data.csv"
    source_data.to_csv(source_data_path, index=False)

    fig, ax = plt.subplots(1, 1, figsize=(20, 12))

    # Plot the histogram (large fonts because the figure is 20x12 inches)
    title_fontsize = 70
    axislabel_fontsize = 70
    tick_fontsize = 50
    ax.hist(frac_disorder, bins=20, color=color, alpha=0.7)
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel('% Disordered', fontsize=axislabel_fontsize)
    ax.set_ylabel('Count', fontsize=axislabel_fontsize)
    ax.grid(True)
    ax.set_axisbelow(True)  # keep grid lines behind the bars
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Calculate the mean and median of the percent disorder
    mean_coverage = np.mean(frac_disorder)
    median_coverage = np.median(frac_disorder)

    # Add vertical line for the mean
    ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2, label=f'Mean: {mean_coverage:.1f}%')

    # Add vertical line for the median
    ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2, label=f'Median: {median_coverage:.1f}%')

    ax.legend(fontsize=50, title_fontsize=50)

    plt.tight_layout()
    plt.savefig(savepath)

def plot_group_disorder_content_hist(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_histograms.png',orient='horizontal'):
    """
    Compare disorder content between the train, test, and fusion benchmark sets based on the TRUE labels,
    as three histograms in one figure saved to ``savepath``.

    Args:
        train_labels, test_labels, benchmark_labels: iterables of per-residue label strings,
            e.g. ['11110000','0001110',...].
        savepath: where the figure is saved.
        orient: 'horizontal' for a 1x3 layout or 'vertical' for a 3x1 layout.

    Raises:
        ValueError: if ``orient`` is neither 'horizontal' nor 'vertical'.
    """
    # Validate before doing any work; any other value used to fail later with an unbound `axes`
    if orient not in ('horizontal', 'vertical'):
        raise ValueError(f"orient must be 'horizontal' or 'vertical', got {orient!r}")

    def fraction_disordered(label_vectors):
        # Fraction of residues labelled 1, per sequence
        fractions = []
        for vec in label_vectors:
            veclist = [int(x) for x in vec]
            fractions.append(sum(veclist)/len(veclist))
        return fractions

    train_frac_disorder = fraction_disordered(train_labels)
    test_frac_disorder = fraction_disordered(test_labels)
    benchmark_frac_disorder = fraction_disordered(benchmark_labels)

    # make a plot
    set_font()
    color_map = {
        'train': '#0072B2',
        'test': '#E69F00',
        'fusion': 'mediumpurple'
    }

    # Create a 1x3 subplot (1 row, 3 columns) or 3x1
    if orient == 'horizontal':
        fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=False)
    else:
        fig, axes = plt.subplots(3, 1, figsize=(5, 15), sharey=False)

    title_fontsize = 26
    axislabel_fontsize = 26
    tick_fontsize = 16

    # (panel title, values, color) in panel order
    panels = [
        ('CAID2 Train', train_frac_disorder, color_map['train']),
        ('CAID2 Test', test_frac_disorder, color_map['test']),
        ('Fusion Oncoproteins', benchmark_frac_disorder, color_map['fusion']),
    ]
    last_panel = len(panels) - 1
    for idx, (ax, (panel_title, fractions, panel_color)) in enumerate(zip(axes, panels)):
        ax.hist(fractions, bins=20, color=panel_color, alpha=0.7)
        ax.set_title(panel_title, fontsize=title_fontsize)
        # Shared-looking axis labels: every panel gets an x label in the horizontal layout but only
        # the last one in the vertical layout; the y label is the mirror image of that rule
        if orient == 'horizontal' or idx == last_panel:
            ax.set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize)
        if orient == 'vertical' or idx == 0:
            ax.set_ylabel('Frequency', fontsize=axislabel_fontsize)
        ax.grid(True)
        ax.set_axisbelow(True)  # keep grid lines behind the bars
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    plt.tight_layout()
    plt.savefig(savepath)
    
def categorize_plddt(values):
    """
    Count how many pLDDT values fall into each confidence bin.

    The bins are "<= 50", "50-70" (50 < x <= 70), "70-90" (70 < x <= 90)
    and "> 90". `values` is traversed exactly once, so any iterable works,
    including generators and other one-shot iterators.

    Values for which every comparison is False (e.g. NaN) are not counted
    in any bin.

    Returns:
        dict mapping each bin label to its count, in the order listed above.
    """
    categories = {"<= 50": 0, "50-70": 0, "70-90": 0, "> 90": 0}
    for x in values:
        if x <= 50:
            categories["<= 50"] += 1
        elif x <= 70:
            categories["50-70"] += 1
        elif x <= 90:
            categories["70-90"] += 1
        elif x > 90:
            # explicit test (rather than a bare else) so NaN stays uncounted
            categories["> 90"] += 1
    return categories


def plot_fusion_sequence_pLDDT_left_to_right(fusion_structure_data, fusiongene, save_path=''):
    """
    Plot each amino acid in the sequence as a separate colored bar based on pLDDT values.

    Every row of `fusion_structure_data` whose FusionGene equals `fusiongene`
    becomes one horizontal strip (strips are ordered by Fusion_Length, shortest
    at the bottom). Each residue is colored by its pLDDT bin, a circular marker
    to the right of the strips shows the bin of the row's Fusion_pLDDT, and a
    black rectangle marks the span between the last `top_hg_UniProt_fus_indices`
    value and the first `top_tg_UniProt_fus_indices` value (presumably the
    head/tail breakpoint region - confirm against the alignment step).

    Args:
        fusion_structure_data: DataFrame containing at least the columns
            FusionGene, seq_id, Fusion_Length, Fusion_pLDDT, Fusion_AA_pLDDTs
            (comma-separated string of per-residue values), top_hg_UniProtID,
            top_hg_UniProt_isoform, top_hg_UniProt_fus_indices, top_tg_UniProtID,
            top_tg_UniProt_isoform and top_tg_UniProt_fus_indices
            (comma-separated integer strings, or missing).
        fusiongene: FusionGene value to plot.
        save_path: directory the outputs are written into.

    Side effects:
        Writes `{save_path}/plddt_sequence_{fusiongene}_source_data.csv` and
        `{save_path}/plddt_sequence_{fusiongene with '::' -> '-'}.png`, prints
        the breakpoint indices per row, and calls `plt.show()`.
    """
    set_font()
    bar_height = 0.7

    # Filter for specific fusion data and preprocess
    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])

    # Sort by Fusion_Length; the fresh 0..n-1 index doubles as the row's y position
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)

    # Define colors for each pLDDT range
    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    def get_color(pLDDT):
        """Return the bin color for one pLDDT value (NaN falls into '<= 50')."""
        if pLDDT > 90:
            return category_colors["> 90"]
        elif pLDDT > 70:
            return category_colors["70-90"]
        elif pLDDT > 50:
            return category_colors["50-70"]
        else:
            return category_colors["<= 50"]

    def parse_indices(value):
        """Turn a comma-separated index string into ints; None when missing.

        Missing CSV cells arrive as NaN, which may be a Python float or a
        numpy float depending on the column dtype, so test for `str` instead
        of comparing the type against `float`.
        """
        if not isinstance(value, str):
            return None
        return [int(x) for x in value.split(',')]

    # Create exactly one figure; a short figure is enough for one or two rows
    n_structures = len(df_of_interest)
    figsize = (10, 2) if n_structures < 3 else (10, 6)
    fig, ax = plt.subplots(figsize=figsize)

    df_of_interest['Fusion_AA_colors'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [get_color(plddt) for plddt in x])
    df_of_interest['Fusion_pLDDT_color'] = df_of_interest['Fusion_pLDDT'].apply(get_color)
    # just save the columns needed for the plot
    # NOTE: this file name uses the raw fusion gene name (including "::"),
    # unlike the PNG below; kept unchanged so existing consumers still find it.
    df_of_interest[['FusionGene','seq_id','Fusion_Length','Fusion_pLDDT','Fusion_AA_pLDDTs','Fusion_AA_colors','Fusion_pLDDT_color',
                    'top_hg_UniProtID','top_hg_UniProt_isoform','top_hg_UniProt_fus_indices',
                    'top_tg_UniProtID','top_tg_UniProt_isoform','top_tg_UniProt_fus_indices']].to_csv(f"{save_path}/plddt_sequence_{fusiongene}_source_data.csv",index=False)

    # Marker size depends only on how many rows are drawn
    if n_structures > 10:
        markersize = 10
    elif n_structures > 5:
        markersize = 16
    else:
        markersize = 12
    # All markers sit in one column just right of the longest sequence
    marker_x = 1.02 * df_of_interest['Fusion_Length'].max()

    for idx, row in df_of_interest.iterrows():
        row_bottom = idx - bar_height / 2  # centers the strip on y = idx
        pLDDT_values = row['Fusion_AA_pLDDTs']

        # Plot each amino acid in the sequence with the respective color
        ax.bar(range(len(pLDDT_values)),
               [bar_height] * len(pLDDT_values), color=row['Fusion_AA_colors'], edgecolor='none',
               bottom=row_bottom)

        # Average-pLDDT marker, taken from this row directly so that two
        # structures of equal length cannot pick up each other's value
        ax.plot(marker_x, idx, marker='o', color="black", markersize=markersize,
                markerfacecolor=row['Fusion_pLDDT_color'], markeredgewidth=2)

        # Add breakpoint box - make sure we actually HAVE at least one index set
        hg_indices = parse_indices(row['top_hg_UniProt_fus_indices'])
        tg_indices = parse_indices(row['top_tg_UniProt_fus_indices'])
        print(hg_indices, tg_indices)

        if (hg_indices is not None) and (tg_indices is not None):
            box_start = min(hg_indices[-1], tg_indices[0])
            box_end = max(hg_indices[-1], tg_indices[0])
        elif hg_indices is not None:
            box_start, box_end = hg_indices[-1], hg_indices[-1]
        elif tg_indices is not None:
            box_start, box_end = tg_indices[0], tg_indices[0]
        else:
            # Nothing to draw; without this guard the box of the previous
            # row would be re-drawn here (or a NameError raised on row 0)
            print(f"no breakpoint indices for structure {idx}, fusion gene {fusiongene}; skipping box")
            continue

        print(f"box indices for structure {idx}, fusion gene {fusiongene}", box_start, box_end)

        # Outline the breakpoint region on top of the strip
        rect = patches.Rectangle((box_start, row_bottom), box_end - box_start, bar_height,
                                 linewidth=2, edgecolor='black', facecolor='none')
        ax.add_patch(rect)

    # Customize plot
    ax.set_yticks([])  # Hide y-axis ticks
    ax.set_yticklabels([])  # Hide y-axis labels
    ax.set_ylim(-0.5, n_structures - 0.5)  # reduce white space at top
    ax.set_xlim(left=0)  # Start x-axis at 0 to make bars flush left
    ax.set_xlabel("Amino Acid Sequence (ordered)", fontsize=14)
    ax.tick_params(axis='x', labelsize=30)

    plt.title(f"{fusiongene} pLDDT Distribution by Amino Acid Sequence", fontsize=16)
    plt.tight_layout()

    # Save figure ("::" is replaced because it is awkward in file names)
    fusiongene_savename = fusiongene.replace("::","-")
    plt.savefig(f"{save_path}/plddt_sequence_{fusiongene_savename}.png", dpi=300)
    plt.show()
    
def plot_favorite_fusion_pLDDT_distribution(fusion_structure_data, fusiongene, save_path=''):
    """
    Make a stacked bar chart of the pLDDT distribution

    One horizontal bar is drawn per row of `fusion_structure_data` whose
    FusionGene equals `fusiongene` (ordered by Fusion_Length). Each bar is
    split into the four pLDDT bins returned by `categorize_plddt`, and a square
    marker to the right of the bar is colored by the bin of the row's
    Fusion_pLDDT.

    Args:
        fusion_structure_data: DataFrame with at least the columns FusionGene,
            Fusion_Length, Fusion_pLDDT and Fusion_AA_pLDDTs (comma-separated
            string of per-residue values).
        fusiongene: FusionGene value to plot.
        save_path: directory the PNG is written into.

    Side effects:
        Writes `{save_path}/plddt_dist_{fusiongene with '::' -> '-'}.png`.
    """
    set_font()
    # Filter for the requested fusion and preprocess
    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])
    df_of_interest['Label'] = df_of_interest['Fusion_Length'].astype(str) + 'AAs'
    # Sort data by Fusion_Length
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)

    # Work per row (not per label): two structures of equal length share a
    # label, and keying by label would silently merge them into one bar.
    labels = df_of_interest['Label'].tolist()
    n_structures = len(labels)
    positions = np.arange(n_structures)

    # Count residues per pLDDT bin for every structure
    categories = ["<= 50", "50-70", "70-90", "> 90"]
    categorized_data = [categorize_plddt(plddt_values) for plddt_values in df_of_interest['Fusion_AA_pLDDTs']]
    counts = {cat: [row_counts[cat] for row_counts in categorized_data] for cat in categories}

    # Define colors for each category
    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    def get_color(plddt):
        """Return the bin color for one pLDDT value (NaN falls into '<= 50')."""
        if plddt > 90:
            return category_colors["> 90"]
        elif plddt > 70:
            return category_colors["70-90"]
        elif plddt > 50:
            return category_colors["50-70"]
        else:
            return category_colors["<= 50"]

    # Create exactly one figure; a short figure is enough for one or two bars
    figsize = (10, 2) if n_structures < 3 else (10, 6)
    fig, ax = plt.subplots(figsize=figsize)
    bottom = np.zeros(n_structures)

    # Stack each category horizontally
    for cat in categories:
        ax.barh(positions, counts[cat], label=cat, color=category_colors[cat], left=bottom)
        bottom += counts[cat]  # Update the left position for the next stack

    # Keep one tick per bar, labelled with the sequence length
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)

    # Marker size depends only on how many bars are drawn
    if n_structures > 10:
        markersize = 10
    elif n_structures > 5:
        markersize = 16
    else:
        markersize = 12
    # Gap between the end of a bar and its marker, relative to the longest sequence
    marker_offset = .02 * df_of_interest['Fusion_Length'].max()

    # Mark each bar with its Fusion_pLDDT bin color, just right of the bar end
    for idx, avg_plddt_value in enumerate(df_of_interest['Fusion_pLDDT']):
        ax.plot(bottom[idx] + marker_offset, idx, marker='s', color="black", markersize=markersize,
                markerfacecolor=get_color(avg_plddt_value), markeredgewidth=2)

    # Large x tick labels; the y axis carries no visible ticks or labels
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', left=False, labelleft=False)

    plt.tight_layout()
    fusiongene_savename = fusiongene.replace("::","-")
    plt.savefig(f"{save_path}/plddt_dist_{fusiongene_savename}.png",dpi=300)

def make_all_favorite_fusion_pLDDT_plots(favorite_fusions,left_to_right=True):
    """
    Draw one pLDDT figure per fusion gene in `favorite_fusions`.

    The structure table is read from disk, each row gets a `seq_id` looked up
    from the FusOn database by its amino-acid sequence, and the SwissProt top
    alignments are left-joined on that `seq_id`. When `left_to_right` is truthy
    the per-residue strip plot is produced, otherwise the stacked-bar plot.
    All figures are written to 'processed_data/figures/fusion_disorder'.
    """
    structure_table = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fuson_db = pd.read_csv("../../data/fuson_db.csv")

    # sequence -> seq_id lookup, used to tag every structure row
    sequence_to_id = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    structure_table['seq_id'] = structure_table['Fusion_Seq'].map(sequence_to_id)

    # keep every structure row, attaching alignment columns where available
    structure_table = structure_table.merge(top_alignments, on="seq_id", how="left")

    # the flag is constant for the whole call, so pick the plotter once
    if left_to_right:
        draw_figure = plot_fusion_sequence_pLDDT_left_to_right
    else:
        draw_figure = plot_favorite_fusion_pLDDT_distribution

    for fusion_name in favorite_fusions:
        draw_figure(structure_table, fusion_name, save_path='processed_data/figures/fusion_disorder')
    
def prep_data_for_ht_disorder_comparison():
    """
    Build one table pairing each fusion structure with its head and tail structures.

    Steps, all driven by CSV files read from fixed relative paths:
      1. Right-join the FusOn head/tail database onto the fusion structure
         table by sequence, so every fusion structure row is kept.
      2. Split the comma-separated `hgUniProt` / `tgUniProt` lists and explode
         them to one (head, tail) combination per row.
      3. Keep only combinations where both IDs have an entry in the head/tail
         structure table, then inner-join that table once for the head
         (`hg_*` columns) and once for the tail (`tg_*` columns).
      4. Drop rows lacking per-residue pLDDTs for head or tail.
      5. Add `hg_label`, `tg_label` and `fusion_label`: strings with one
         character per residue as produced by `apply_plddt_thresh`.

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index.
    """
    ht_structure_data = pd.read_csv('processed_data/fusionpdb/heads_tails_structural_data.csv')
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')

    all_hts_with_structures = ht_structure_data['UniProtID'].unique().tolist()

    fuson_ht_db = pd.read_csv('../../data/blast/fuson_ht_db.csv')[['seq_id','aa_seq','fusiongenes','hgUniProt','tgUniProt']]

    merged = pd.merge(
        fuson_ht_db.rename(columns={'aa_seq':'Fusion_Seq'}),
        fusion_structure_data[['FusionGID', 'Fusion_Seq','Fusion_pLDDT','Fusion_AA_pLDDTs']],
        on='Fusion_Seq',
        how='right'
    )

    # The right join leaves NaN head/tail IDs for structures that are absent
    # from fuson_ht_db. Those rows cannot be paired with any head/tail
    # structure, so drop them here instead of letting `.split` fail on a float.
    merged = merged.dropna(subset=['hgUniProt', 'tgUniProt'])

    # One row per (head ID, tail ID) combination
    merged['hgUniProt'] = merged['hgUniProt'].apply(lambda x: x.split(','))
    merged['tgUniProt'] = merged['tgUniProt'].apply(lambda x: x.split(','))
    merged = merged.explode('hgUniProt')
    merged = merged.explode('tgUniProt')
    merged = merged.loc[
        merged['hgUniProt'].isin(all_hts_with_structures) &
        merged['tgUniProt'].isin(all_hts_with_structures)
    ].reset_index(drop=True)

    # Attach head structure info, then tail structure info
    merged = pd.merge(
        merged,
        ht_structure_data.rename(columns=
            {'UniProtID':'hgUniProt',
            'Avg pLDDT': 'hg_pLDDT',
            'All pLDDTs': 'hg_AA_pLDDTs',
            'Seq': 'hg_seq'}),
        on='hgUniProt',
        how='inner'
    )

    merged = pd.merge(
        merged,
        ht_structure_data.rename(columns=
            {'UniProtID':'tgUniProt',
            'Avg pLDDT': 'tg_pLDDT',
            'All pLDDTs': 'tg_AA_pLDDTs',
            'Seq': 'tg_seq'}),
        on='tgUniProt',
        how='inner'
    )
    has_both_plddts = merged['hg_AA_pLDDTs'].notna() & merged['tg_AA_pLDDTs'].notna()
    merged = merged.loc[has_both_plddts].reset_index(drop=True)

    def plddts_to_label(plddt_string):
        """Map a comma-separated pLDDT string to one label character per residue."""
        return ''.join(apply_plddt_thresh(float(value)) for value in plddt_string.split(','))

    # finally, calculate the per-residue label strings
    for label_column, plddt_column in (
        ('hg_label', 'hg_AA_pLDDTs'),
        ('tg_label', 'tg_AA_pLDDTs'),
        ('fusion_label', 'Fusion_AA_pLDDTs'),
    ):
        merged[label_column] = merged[plddt_column].apply(plddts_to_label)

    return merged

def apply_plddt_thresh(y):
    """
    Return '1' when the pLDDT value `y` is strictly below 68.8, else '0'.

    A value for which the comparison is False (including NaN) yields '0'.
    """
    return '1' if y < 68.8 else '0'

def plot_fusion_stats_boxplots(data, save_path="fusion_disorder_boxplots.png"):
    """
    Draw one box plot per column of `data` and save the figure.

    Args:
        data: DataFrame whose columns are the metrics to plot; NaN entries are
            dropped per column before plotting.
        save_path: file the figure is written to (300 dpi).

    The figure is saved before `plt.show()` is called: on interactive backends
    the figure can already be closed once `show()` returns, and saving
    afterwards would then write an empty image.
    """
    set_font()
    # Create box plots
    plt.figure(figsize=(6, 5))
    # for ones that are 100% disordered, AUROC was NaN, so drop these
    box = plt.boxplot([data[col].dropna() for col in data.columns], labels=data.columns, patch_artist=True)

    # Fill and outline every box in the same pink
    for patch in box['boxes']:
        patch.set_facecolor('#ff68b4')
        patch.set_edgecolor('#ff68b4')

    # Black medians stay visible against the filled boxes
    for median in box['medians']:
        median.set_color('black')

    plt.title(f"Per-Residue Disorder (n={len(data)})",fontsize=22)
    plt.xticks(rotation=20,fontsize=22)
    plt.yticks(fontsize=22)

    # Save first, then show
    plt.tight_layout()
    plt.savefig(save_path,dpi=300)
    plt.show()

def plot_fusion_frac_disorder_r2(actual_values, predicted_values, save_path="fusion_pred_disorder_r2.png"):
    """
    Scatter predicted against actual values with an identity line and R^2 annotation.

    Args:
        actual_values: values plotted on the x axis (labelled 'AlphaFold-pLDDT').
        predicted_values: values plotted on the y axis (labelled 'FusOn-pLM-Diso').
        save_path: file the figure is written to (300 dpi).

    The R^2 text is placed at the fixed data coordinates (0, 92), so it assumes
    the plotted range covers that point.

    The figure is saved before `plt.show()` is called: on interactive backends
    the figure can already be closed once `show()` returns, and saving
    afterwards would then write an empty image.
    """
    set_font()
    plt.figure(figsize=(6, 6))
    r2 = r2_score(actual_values, predicted_values)

    plt.scatter(actual_values, predicted_values, alpha=0.5, label="Predictions", color="#ff68b4")

    # Identity line across the range of the actual values
    lowest, highest = min(actual_values), max(actual_values)
    plt.plot([lowest, highest], [lowest, highest], 'k--', label='Ideal Fit')
    plt.text(0, 92, f"$R^2$={r2:.2f}", fontsize=32)

    # Adjusting font sizes and setting font properties
    plt.xlabel('AlphaFold-pLDDT',size=32)
    plt.ylabel('FusOn-pLM-Diso',size=32)
    plt.title(f"% Disordered (n={len(actual_values)})",size=32)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(prop={'size': 16})

    # Save first, then show
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    
def main():
    """
    Produce the figures of this script from files at fixed relative paths.

    In order: the AUROC curve comparing FusOn-pLM with ESM-2 on the test
    split, the disorder-content histograms for the CAID2 test set, head
    proteins, tail proteins and fusion oncoproteins, and the per-residue
    pLDDT strip plots for four selected fusions.
    """
    set_font()
    output_dir = "results/final"  # alternative used during testing: "results/test"

    # Sequence -> ID and sequence -> label lookups for the test split
    test_split = pd.read_csv('splits/test_df.csv')
    seq_ids_dict = dict(zip(test_split['Sequence'], test_split['IDs']))
    seq_label_dict = dict(zip(test_split['Sequence'], test_split['Label']))
    # Loaded but not used below
    best_caid_model_results = pd.read_csv(f"{output_dir}/best_caid_model_results.csv")
    make_auroc_curve(
        results_dir=output_dir,
        seq_label_dict=seq_label_dict,
        seq_ids_dict=seq_ids_dict,
        path_to_results_of_interest="trained_models/fuson_plm/best/caid_hyperparam_screen_test_probs.csv",
        model_alias="FusOn-pLM",
        path_to_esm_results="trained_models/esm2_t33_650M_UR50D/best/caid_hyperparam_screen_test_probs.csv",
        with_rankings=True,
    )

    # CAID2 rows belonging to the test split
    all_splits = pd.read_csv("splits/splits.csv")
    caid2_test_data = all_splits.loc[all_splits['Split'] == 'Test']
    caid2_test_labels = caid2_test_data['Label'].tolist()
    caid2_test_ids = caid2_test_data['IDs'].tolist()

    # fusions, heads, and tails
    fusion_ht_data = prep_data_for_ht_disorder_comparison()
    os.makedirs("processed_data/figures", exist_ok=True)

    # One row per unique head / tail / fusion sequence
    head_data = fusion_ht_data.drop_duplicates(['hg_seq']).reset_index(drop=True)
    tail_data = fusion_ht_data.drop_duplicates(['tg_seq']).reset_index(drop=True)
    fusion_data = fusion_ht_data.drop_duplicates(['Fusion_Seq']).reset_index(drop=True)

    plt.rc('text', usetex=False)
    math_part = r"$n$"

    os.makedirs("processed_data/figures/histograms", exist_ok=True)
    # (labels, ids, title prefix, color, output file) for each histogram, in drawing order
    histogram_specs = [
        (caid2_test_labels, caid2_test_ids,
         "CAID2 Disorder-NOX", "black",
         'processed_data/figures/histograms/disorder_nox_histogram.png'),
        (head_data['hg_label'].tolist(), head_data['hgUniProt'].tolist(),
         "Head Proteins", "#df8385",
         'processed_data/figures/histograms/heads_histogram.png'),
        (tail_data['tg_label'].tolist(), tail_data['tgUniProt'].tolist(),
         "Tail Proteins", "#6ea4da",
         'processed_data/figures/histograms/tails_histogram.png'),
        (fusion_data['fusion_label'].tolist(), fusion_data['seq_id'].tolist(),
         "Fusion Oncoproteins", "mediumpurple",
         'processed_data/figures/histograms/fusions_histogram.png'),
    ]
    for group_labels, group_ids, title_prefix, hist_color, out_file in histogram_specs:
        plot_disorder_content_hist(
            group_labels,
            group_ids,
            title=f"{title_prefix} ({math_part}={len(group_labels):,})",
            color=hist_color,
            savepath=out_file,
        )

    os.makedirs("processed_data/figures/fusion_disorder", exist_ok=True)
    favorite_fusions = [
        "EWSR1::FLI1",
        "PAX3::FOXO1",
        "EML4::ALK",
        "SS18::SSX1",
    ]
    make_all_favorite_fusion_pLDDT_plots(favorite_fusions, left_to_right=True)

# Run the full figure pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()