File size: 33,907 Bytes
bae913a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
import os
import numpy as np
import re
import pandas as pd
import requests

from fuson_plm.utils.logging import open_logfile, log_update
from fuson_plm.utils.constants import DELIMITERS, VALID_AAS
from fuson_plm.utils.data_cleaning import check_columns_for_listlike, find_invalid_chars

from fuson_plm.benchmarking.caid.scrape_fusionpdb import scrape_fusionpdb_level_2_3
from fuson_plm.benchmarking.caid.process_fusion_structures import process_fusions_and_hts
            
def download_fasta(uniprotid, includeIsoform, output_file, timeout=60):
    """
    Download the UniProt FASTA record(s) for one accession and append them to a file.

    Args:
        uniprotid (str): UniProt accession inserted into the search query.
        includeIsoform: value interpolated verbatim into the ``includeIsoform``
            query parameter of the UniProt REST search URL.
        output_file (str): FASTA file to append to (created if it does not exist).
        timeout (float): seconds to wait for the server before giving up. Without
            a timeout, ``requests`` can block indefinitely on a stalled connection.

    Request failures (connection errors, timeouts, non-2xx responses) are logged
    via ``log_update`` and swallowed; nothing is written to ``output_file`` then.
    """
    url = (
        "https://rest.uniprot.org/uniprotkb/search?format=fasta"
        f"&includeIsoform={includeIsoform}&query=accession%3A{uniprotid}"
        "&size=500&sort=accession+asc"
    )
    try:
        response = requests.get(url, timeout=timeout)
        # Turn 4xx/5xx responses into an HTTPError so they are handled below
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        log_update(f"An error occurred: {e}")
        return

    # Append (not overwrite) so several accessions can accumulate in one FASTA file
    with open(output_file, 'a') as file:
        file.write(response.text)

    log_update(f"FASTA file for {uniprotid} successfully downloaded and added to '{output_file}'")

# Test Sequences (CAID-2 Disorder-NOX)
def parse_caid_txt(fast_file):
    '''
    Parse a CAID-style FASTA file whose records are a header, sequence line(s) and a label line.

    Expected layout of each record:
        >ID
        Sequence (may span several lines)
        Label    (a line consisting only of '0', '1' and '-')

    Blank lines are ignored. Any lines after the label line and before the next
    header are ignored. Records lacking a sequence or a label are skipped.

    Args:
        fast_file (str): path to the FASTA-like text file.

    Returns:
        dict: sequence -> (label, seq_id). If the same sequence occurs in several
        records, the last one wins.
    '''
    seq_to_label = {}

    def _store(seq_id, sequence, label):
        # Keep a finished record only if it has both a sequence and a label
        if label is not None and sequence:
            seq_to_label[sequence] = (label, seq_id)

    seq_id = None
    sequence = ""
    label = None
    reading_sequence = False

    with open(fast_file, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # all() over an empty string is True, so without this a blank line
                # would be taken as an (empty) label and end the record early
                continue
            if line.startswith(">"):
                _store(seq_id, sequence, label)
                seq_id = line[1:]  # ID without the leading '>'
                label = None
                sequence = ""
                reading_sequence = True
            elif reading_sequence:
                if all(c in "01-" for c in line):
                    label = line
                    reading_sequence = False
                else:
                    sequence += line

    # The final record has no following header, so store it explicitly
    _store(seq_id, sequence, label)
    return seq_to_label

def check_df_for_mismatched_labels(sd):
    """
    Log each row whose 'Sequence' and 'Label' strings differ in length, then the total count.

    Args:
        sd (pd.DataFrame): dataframe with 'Sequence' and 'Label' columns.
    """
    log_update("\tChecking dataframe for mismatched sequences and labels...")
    n_mismatched = 0
    for row_idx, seq, lbl in zip(sd.index, sd['Sequence'], sd['Label']):
        seq_len, lbl_len = len(seq), len(lbl)
        if seq_len == lbl_len:
            continue
        n_mismatched += 1
        log_update(f"\t\tLength mismatch at index {row_idx}: sequence length = {seq_len}, label length = {lbl_len}")

    log_update(f"\t\tTotal mismatched lengths/labels: {n_mismatched}")


def process_caid2_disorder_nox_test(caid_path):
    """
    Build the CAID-2 Disorder-NOX test dataframe from its FASTA file.

    Args:
        caid_path (str): path to CAID-2_Disorder_NOX_Testing_Sequences.fasta

    Returns:
        pd.DataFrame with columns ID, Sequence, Label and Split (always 'Test').
    """
    log_update("Processing CAID-2-Disorder-NOX Testing Dataset")
    seq_to_label = parse_caid_txt(caid_path)

    # Re-key the parsed records by ID; a later record with an already-seen ID replaces the earlier one
    by_id = {seq_id: (seq, lbl) for seq, (lbl, seq_id) in seq_to_label.items()}
    log_update(f"\tTotal sequences: {len(by_id)}")

    records = list(by_id.values())
    caid_df = pd.DataFrame({
        'ID': list(by_id),
        'Sequence': [seq for seq, _ in records],
        'Label': [lbl for _, lbl in records],
        'Split': 'Test'
    })

    check_df_for_mismatched_labels(caid_df)
    return caid_df
    
# Training Sequences (flDPnn and IDP-CRF)
# fldpnn Training Sequences
def parse_fldpnn_fasta(file_path):
    """
    Parse an flDPnn annotation file, keeping only the intrinsic-disorder labels.

    Each record has a header, one or more sequence lines, and exactly 5 annotation
    lines, of which only the first is kept:

    >Disprot ID
    Amino acid sequence
    Experimental annotation for intrinsic disorder          <- kept
    Experimental annotation for disordered protein binding
    Experimental annotation for disordered DNA binding
    Experimental annotation for disordered RNA binding
    Experimental annotation for disordered flexible linkers

    Sequence lines are lines made only of uppercase letters; every other
    non-blank line counts as an annotation line. Blank lines are ignored.

    Args:
        file_path (str): path to the flDPnn text file.

    Returns:
        tuple[list[str], list[str], list[str]]: (ids, sequences, labels), aligned by position.

    Raises:
        ValueError: if a record with a sequence and labels does not have exactly
            5 annotation lines.
    """
    n_annotation_lines = 5
    sequence_line = re.compile(r'^[A-Z]+$')

    ids, sequences, labels = [], [], []

    def _store(seq_id, sequence, disorder_labels, seen_label_lines):
        # Keep the record only if it has both a sequence and a disorder label
        if not (sequence and disorder_labels):
            return
        if seen_label_lines != n_annotation_lines:
            raise ValueError(
                f"Record '{seq_id}' has {seen_label_lines} annotation lines; "
                f"expected {n_annotation_lines}"
            )
        ids.append(seq_id)
        sequences.append(sequence)
        labels.append(''.join(disorder_labels))

    seq_id = ""
    current_sequence = ""
    current_labels = []
    seen_label_lines = 0

    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise be counted as an annotation line
                continue
            if line.startswith('>'):
                _store(seq_id, current_sequence, current_labels, seen_label_lines)
                seq_id = line[1:]  # ID without the leading '>'
                current_sequence = ""
                current_labels = []
                seen_label_lines = 0
            elif sequence_line.match(line):
                current_sequence += line
            else:
                seen_label_lines += 1
                # Only the first annotation line (intrinsic disorder) is kept
                if seen_label_lines == 1:
                    current_labels.append(line)

    # The last record has no following header; store (and validate) it explicitly
    _store(seq_id, current_sequence, current_labels, seen_label_lines)

    return ids, sequences, labels

def parse_idp_crf_fasta(file_path):
    """
    Parse the IDP-CRF training file into parallel lists of IDs, sequences and labels.

    A record starts with a '>' header. Lines made only of uppercase letters are
    sequence lines; every other line (blank lines included) is a label line.
    Sequence lines and label lines are each concatenated. A record is emitted
    only if it has at least one sequence line and at least one label line.

    Args:
        file_path (str): path to the IDP-CRF text file.

    Returns:
        tuple[list[str], list[str], list[str]]: (ids, sequences, labels), aligned by position.
    """
    ids, sequences, labels = [], [], []
    header_id = ""
    seq_chunks = []
    label_chunks = []

    with open(file_path, 'r') as handle:
        stripped = [raw.strip() for raw in handle]

    # A trailing sentinel header flushes the final record through the same code path
    for line in stripped + [">"]:
        if line.startswith('>'):
            if seq_chunks and label_chunks:
                ids.append(header_id)
                sequences.append(''.join(seq_chunks))
                labels.append(''.join(label_chunks))
            header_id = line[1:]  # ID without the leading '>'
            seq_chunks, label_chunks = [], []
        elif re.match('^[A-Z]+$', line):
            seq_chunks.append(line)
        else:
            label_chunks.append(line)

    return ids, sequences, labels

def process_fldpnn(fldpnn_path, split="training"):
    """
    Parse an flDPnn annotation file (training or validation) into a dataframe.

    Args:
        fldpnn_path (str): path to the flDPnn text file.
        split (str): "training" tags every row with Split 'Train'; any other value tags them 'Val'.

    Returns:
        pd.DataFrame with columns Sequence, Label, Split and ID.
    """
    log_update(f"\nProcessing flDPnn {split} dataset")
    # parse_fldpnn_fasta already keeps only the first (intrinsic-disorder) annotation
    # line per record, so the labels need no further cleaning here
    fldpnn_ids, fldpnn_seqs, fldpnn_labels = parse_fldpnn_fasta(fldpnn_path)

    log_update(f"\tTotal labels: {len(fldpnn_labels)}, total sequences: {len(fldpnn_seqs)}, total IDs: {len(fldpnn_ids)}")

    fldpnn_df = pd.DataFrame({'Sequence': fldpnn_seqs,
                              'Label': fldpnn_labels,
                              "Split": "Train" if split == "training" else "Val",
                              "ID": fldpnn_ids})
    check_df_for_mismatched_labels(fldpnn_df)

    return fldpnn_df

def combine_fldpnn_train_val(fldpnn_train_df, fldpnn_val_df):
    """
    Merge the flDPnn train and validation sets into one training set.

    Sequences present in both splits are resolved as follows:
      * identical labels -> the validation copy is dropped and the train copy kept;
      * different labels -> every row with that sequence is dropped.
    Sequences repeated only within a single split are logged and left unchanged.

    Args:
        fldpnn_train_df (pd.DataFrame): rows with Split == 'Train' (Sequence, Label, Split, ID).
        fldpnn_val_df (pd.DataFrame): rows with Split == 'Val' (same columns).

    Returns:
        pd.DataFrame: the combined rows with a fresh 0..n-1 index.
    """
    log_update("\nJoining flDPnn train and val sets into one training set for CAID predictor")
    # ignore_index avoids overlapping index labels from the two inputs
    combined = pd.concat([fldpnn_train_df, fldpnn_val_df], ignore_index=True)

    # Only sequences found in BOTH splits are cross-split duplicates. A sequence repeated
    # inside one split would make the Train/Val lookups below empty (and .iloc[0] fail).
    shared = (set(combined.loc[combined['Split'] == 'Train', 'Sequence'])
              & set(combined.loc[combined['Split'] == 'Val', 'Sequence']))
    repeated = combined.loc[combined['Sequence'].duplicated(), 'Sequence'].unique().tolist()
    duplicates = [seq for seq in repeated if seq in shared]
    within_split_only = [seq for seq in repeated if seq not in shared]

    n_rows_with_duplicates = int(combined['Sequence'].isin(duplicates).sum())
    log_update(f"\t{len(duplicates)} sequences in both train and val datasets, corresponding to {n_rows_with_duplicates} rows")
    if within_split_only:
        log_update(f"\t{len(within_split_only)} sequences repeated within a single split only; left unchanged")

    for dup in duplicates:
        dup_rows = combined.loc[combined['Sequence'] == dup]
        train_rows = dup_rows.loc[dup_rows['Split'] == 'Train']
        val_rows = dup_rows.loc[dup_rows['Split'] == 'Val']
        train_id, train_label = train_rows['ID'].iloc[0], train_rows['Label'].iloc[0]
        val_id, val_label = val_rows['ID'].iloc[0], val_rows['Label'].iloc[0]
        same_labels = train_label == val_label
        log_update(f"\t\tTrain ID: {train_id}\tVal ID: {val_id}\tSame labels: {same_labels}\tSequence: {dup}")

        if not same_labels:
            log_update(f"\t\t\tSince labels are not equal, removing sequence completely")
            combined = combined.loc[combined['Sequence'] != dup]
        else:
            log_update(f"\t\t\tSince labels are equal, removing validation copy")
            combined = combined.loc[(combined['Sequence'] != dup) | (combined['Split'] == 'Train')]

    combined = combined.reset_index(drop=True)
    log_update(f"\tLength of joined flDPnn data: {len(combined)}")

    return combined
    
def process_idp_crf_train(idp_crf_train_path):
    """
    Parse the IDP-CRF training file, keeping only records whose label length equals the sequence length.

    Labels are kept as the raw strings returned by ``parse_idp_crf_fasta``;
    length-mismatched records are logged and dropped.

    Args:
        idp_crf_train_path (str): path to IDP-CRF_Training_Dataset.txt

    Returns:
        pd.DataFrame with columns Sequence, Label, Split ('Train') and ID.
    """
    log_update("\nProcessing IDP-CRF training dataset")
    raw_ids, raw_seqs, raw_labels = parse_idp_crf_fasta(idp_crf_train_path)
    log_update(f"\tTotal labels: {len(raw_labels)}, total sequences: {len(raw_seqs)}, total IDs: {len(raw_ids)}")

    kept_ids, kept_seqs, kept_labels = [], [], []
    n_mismatched = 0
    log_update("\tCleaning labels and counting length-mismatched examples...")
    for i, (rec_id, seq, lbl) in enumerate(zip(raw_ids, raw_seqs, raw_labels)):
        if len(seq) == len(lbl):
            kept_ids.append(rec_id)
            kept_labels.append(lbl)
            kept_seqs.append(seq)
            continue
        log_update(f"\t\tLength mismatch at index {i}: sequence length = {len(seq)}, label length = {len(lbl)}")
        n_mismatched += 1

    log_update(f"\t\tMismatched lengths/labels: {n_mismatched}")

    # Re-check the assembled dataframe; it should report zero mismatches
    idp_crf_df = pd.DataFrame({'Sequence': kept_seqs,
                               'Label': kept_labels,
                               "Split": "Train",
                               "ID": kept_ids})
    check_df_for_mismatched_labels(idp_crf_df)

    return idp_crf_df

def find_agreeing_labels(row, lab1="", lab2=""):
    """
    Reconcile two candidate label columns of one row.

    Args:
        row: mapping (e.g. a DataFrame row) that holds both label columns.
        lab1 (str): name of the first label column.
        lab2 (str): name of the second label column.

    Returns:
        The ``lab2`` value if ``lab1`` is a float NaN (missing), the ``lab1`` value
        if ``lab2`` is a float NaN, the shared value if both agree, and ``np.nan``
        if they disagree.
    """
    val1 = row[lab1]
    val2 = row[lab2]

    # A missing value cannot conflict, so fall back to the other column.
    # isinstance (rather than type()==float) also recognises NumPy float NaNs such as np.float64.
    if isinstance(val1, float) and np.isnan(val1):
        return val2
    if isinstance(val2, float) and np.isnan(val2):
        return val1
    return val1 if val1 == val2 else np.nan

def get_unique_ids(row):
    """
    Return the row's per-source IDs as one comma-separated string.

    IDs are taken in the order the sources are listed in the comma-separated
    'Source' field; an ID equal to one already collected is not repeated.

    Args:
        row: mapping with 'Source', 'IDP-CRF ID', 'flDPnn ID' and 'CAID-2_Disorder_NOX ID'.

    Returns:
        str
    """
    id_by_source = {
        "IDP-CRF": row["IDP-CRF ID"],
        "flDPnn": row["flDPnn ID"],
        "CAID-2_Disorder_NOX": row["CAID-2_Disorder_NOX ID"]
    }
    ordered_ids = [id_by_source[src] for src in row["Source"].split(",")]
    # dict.fromkeys keeps the first occurrence of each ID, in order
    return ",".join(dict.fromkeys(ordered_ids))

def _finalize_caid2_record(record, rec_id, label_by_id, label_by_seq):
    """Join one .caid record's per-residue lists into comma-separated strings and attach its true labels."""
    record['prob_1'] = ",".join(record['prob_1'])
    record['pred_labels'] = ",".join(record['pred_labels'])
    # Prefer a match by ID, fall back to a match by sequence, otherwise mark as unlabeled (NaN)
    if rec_id in label_by_id:
        true_labels = label_by_id[rec_id]
    elif record['sequence'] in label_by_seq:
        true_labels = label_by_seq[record['sequence']]
    else:
        record['labels'] = np.nan
        return
    record['labels'] = ",".join(true_labels)


def parse_caid2_results(processed_caid2_df,lines):
    """
    Parse the lines of one CAID-2 competition ``.caid`` prediction file.

    Each record starts with a '>ID' header line followed by one tab-separated line
    per residue: index, amino acid, probability and, in some files only, a
    predicted label. True labels come from ``processed_caid2_df``: matched by ID
    first, then by sequence. Records matching neither are dropped. Blank lines
    and lines before the first header are ignored.

    Args:
        processed_caid2_df (pd.DataFrame): processed CAID-2 Disorder-NOX data with
            'ID', 'Sequence' and 'Label' columns.
        lines (list[str]): raw lines of the .caid file (e.g. from ``readlines()``).

    Returns:
        pd.DataFrame with columns seq_id, sequence, prob_1, pred_labels and labels,
        where prob_1, pred_labels and labels are comma-separated strings. The
        pred_labels column is dropped when the file provides no predicted labels.
    """
    all_caid2_disorder_nox_sequences = processed_caid2_df['Sequence'].tolist()
    # Dict lookups instead of scanning lists/dataframes for every record
    label_by_id = dict(zip(processed_caid2_df['ID'], processed_caid2_df['Label']))
    label_by_seq = dict(zip(all_caid2_disorder_nox_sequences, processed_caid2_df['Label']))

    results = {}
    cur_id = None
    for line in lines:
        if not line.strip():
            continue  # blank lines carry no residue data
        if line.startswith(">"):
            # A new header closes the record currently being read
            if cur_id is not None:
                _finalize_caid2_record(results[cur_id], cur_id, label_by_id, label_by_seq)
            cur_id = line[1:].strip('\t').strip('\n')
            results[cur_id] = {
                'sequence': '',
                'prob_1': [],
                'pred_labels': []
            }
        elif cur_id is not None:
            # Residue line: index, amino acid, probability[, predicted label]
            fields = line.strip('\n').split('\t')
            aa, prob = fields[1], fields[2]
            label = fields[3] if len(fields) == 4 else ''
            record = results[cur_id]
            record['sequence'] += aa
            record['prob_1'].append(prob)
            record['pred_labels'].append(label)

    # The last record is not followed by a header, so finalize it here
    if cur_id is not None:
        _finalize_caid2_record(results[cur_id], cur_id, label_by_id, label_by_seq)

    # Explicit columns keep the frame well-formed even when no records were parsed
    df = (
        pd.DataFrame.from_dict(results, orient='index',
                               columns=['sequence', 'prob_1', 'pred_labels', 'labels'])
        .reset_index()
        .rename(columns={'index': 'seq_id'})
    )
    df = df.loc[df['labels'].notna()].reset_index(drop=True)

    # Files without a predicted-label column yield only empty labels; drop the column then
    no_pred_labels = len(df) > 0 and df['pred_labels'].map(lambda s: not s.replace(',', '')).all()
    if no_pred_labels:
        df = df.drop(columns=['pred_labels'])
        log_update(f"\t\tno predicted labels provided for this dataset; only probabilities")
    log_update(f"\t\t{len(df)}/{len(all_caid2_disorder_nox_sequences)} total CAID2-Nox sequences")
    return df
    
def parse_all_caid2_results(processed_caid2_df, caid_raw_folder="raw_data/caid2_competition_results",
                            save_dir="processed_data/caid2_competition_results"):
    """
    Parse every competitor file in ``caid_raw_folder`` and save one CSV per competitor.

    Args:
        processed_caid2_df (pd.DataFrame): processed CAID-2 Disorder-NOX data
            (passed through to ``parse_caid2_results``).
        caid_raw_folder (str): directory holding the raw competition result files.
        save_dir (str): output directory (created if missing). Each file is written
            as ``<competitor>_CAID-2_Disorder_NOX.csv``, where <competitor> is the
            file name up to '.caid'.
    """
    os.makedirs(save_dir, exist_ok=True)

    log_update(f"\nExtracting all CAID-2_Disorder_NOX results from CAID2 competition results files...")
    # sorted(): os.listdir order is arbitrary, so sort for a reproducible processing/log order
    for caid_file in sorted(os.listdir(caid_raw_folder)):
        caid_path = os.path.join(caid_raw_folder, caid_file)
        if not os.path.isfile(caid_path):
            continue  # sub-directories cannot be opened as result files
        with open(caid_path, "r") as f:
            lines = f.readlines()

        log_update(f"\t{caid_file}:")
        results_df = parse_caid2_results(processed_caid2_df, lines)
        competitor_name = caid_file.split('.caid')[0]
        results_df.to_csv(os.path.join(save_dir, f"{competitor_name}_CAID-2_Disorder_NOX.csv"), index=False)
    
def make_train_df(fldpnn_df, idp_crf_df):
    """
    Build the unified training set by outer-joining the flDPnn and IDP-CRF data on sequence.

    Steps:
      1. Outer-merge on 'Sequence', recording in 'Source' which dataset(s) contributed each row.
      2. Reconcile the two label columns with ``find_agreeing_labels``. Rows whose labels
         disagree are dropped, as are rows with one hard-coded sequence that has a
         known conflict with DisProt.
      3. Drop rows whose sequence contains characters outside ``VALID_AAS``.

    Args:
        fldpnn_df (pd.DataFrame): flDPnn rows with Sequence, Label, Split and ID columns.
        idp_crf_df (pd.DataFrame): IDP-CRF rows with the same columns.

    Returns:
        pd.DataFrame with columns Sequence, IDP-CRF ID, flDPnn ID, Split ('Train'),
        Source, Label and 'No Label Conflicts'.
    """
    # Sequence dropped because of a known label conflict with DisProt
    conflict_seq="MASREEEQRETTPERGRGAARRPPTMEDVSSPSPSPPPPRAPPKKRMRRRIESEDEEDSSQDALVPRTPSPRPSTSAADLAIAPKKKKKRPSPKPERPPSPEVIVDSEEEREDVALQMVGFSNPPVLIKHGKGGKRTVRRLNEDDPVARGMRTQEEEEEPSEAESEITVMNPLSVPIVSAWEKGMEAARALMDKYHVDNDLKANFKLLPDQVEALAAVCKTWLNEEHRGLQLTFTSKKTFVTMMGRFLQAYLQSFAEVTYKHHEPTGCALWLHRCAEIEGELKCLHGSIMINKEHVIEMDVTSENGQRALKEQSSKAKIVKNRWGRNVVQISNTDARCCVHDAACPANQFSGKSCGMFFSEGAKAQVAFKQIKAFMQALYPNAQTGHGHLLMPLRCECNSKPGHAPFLGRQLPKLTPFALSNAEDLDADLISDKSVLASVHHPALIVFQCCNPVYRNSRAQGGGPNCDFKISAPDLLNALVMVRSLWSENFTELPRMVVPEFKWSTKHQYRNVSLPVAHSDARQNPFDF"

    # Prefix label/ID columns with their source so both survive the merge
    idp_crf_df = idp_crf_df.rename(columns={'Label':'IDP-CRF Label', 'ID': 'IDP-CRF ID'}).drop(columns=['Split'])
    fldpnn_df = fldpnn_df.rename(columns={'Label':'flDPnn Label', 'ID': 'flDPnn ID'}).drop(columns=['Split'])

    ########### Combine fldpnn and idp crf
    log_update("\nJoining flDPnn and IDP-CRF data by sequence make unified training set")
    train_df = pd.merge(idp_crf_df,
                        fldpnn_df,
                        on='Sequence',
                        how='outer',
                        indicator=True)
    train_df["Split"] = "Train"
    # Translate the merge indicator into the originating dataset name(s)
    train_df['Source'] = train_df['_merge'].map({
        'left_only': 'IDP-CRF',
        'right_only': 'flDPnn',
        'both': 'IDP-CRF,flDPnn'
    })
    train_df = train_df.drop(columns=["_merge"])
    log_update(f"\tIDP-CRF dataset size: {len(idp_crf_df)}\n\tfLDpnn dataset size: {len(fldpnn_df)}\n\tinitial train dataset size: {len(train_df)}")

    # Sequences contributed by both datasets
    log_update(f"\tChecking for sequences in both datasets...")
    duplicates = train_df[train_df["Source"].str.contains(",")]['Sequence'].unique().tolist()
    n_rows_with_duplicates = len(train_df[train_df['Sequence'].isin(duplicates)])
    log_update(f"\t\t{len(duplicates)} sequences in both datasets, corresponding to {n_rows_with_duplicates} rows")

    # Reconcile IDP-CRF and flDPnn labels; NaN marks a conflict
    train_df["Label"] = train_df.apply(lambda row: find_agreeing_labels(row,lab1="IDP-CRF Label",lab2="flDPnn Label"),axis=1)
    train_df["No Label Conflicts"] = train_df["Label"].notna()
    log_update(f"\tChecked for label inconsistencies between IDP-CRF and flDPnn on the same sequence:")
    # rename_axis + reset_index(name=...) yields the same two columns on pandas 1.x and 2.x
    match_str = (train_df['No Label Conflicts'].value_counts()
                 .rename_axis('No Label Conflicts')
                 .reset_index(name='count')
                 .to_string(index=False))
    match_str = "\t\t" + match_str.replace("\n","\n\t\t")
    log_update(match_str)

    # Drop label conflicts and the known-bad sequence; count what is actually dropped
    to_drop = (train_df['Sequence'] == conflict_seq) | ~train_df['No Label Conflicts']
    log_update(f"\tDropping rows with label mismatch or known error (total={int(to_drop.sum())})")
    train_df = train_df.loc[~to_drop].reset_index(drop=True)

    # The reconciled 'Label' column replaces the per-source label columns
    train_df = train_df.drop(columns=["IDP-CRF Label","flDPnn Label"])
    log_update(f"\t\tNew dataset size: {len(train_df)}")

    ######## Final checks
    # Check for list-like values in key columns (return value not used here)
    cols_of_interest =  ['Sequence','Split','Label','IDP-CRF ID','flDPnn ID']
    check_columns_for_listlike(train_df, cols_of_interest, DELIMITERS)

    # Check for invalid characters
    train_df['invalid_chars'] = train_df['Sequence'].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    n_invalid = train_df['invalid_chars'].str.len()
    all_invalid_chars = set().union(*train_df['invalid_chars'])
    log_update(f"\tchecking for invalid characters...\n\t\tset of all invalid characters discovered within train_df: {all_invalid_chars}")

    # Dropping rows with invalid characters (should be none)
    log_update(f"\tDropping rows with invalid characters (total={int((n_invalid > 0).sum())})")
    train_df = train_df.loc[n_invalid == 0].reset_index(drop=True)
    train_df = train_df.drop(columns=['invalid_chars'])
    log_update(f"\t\tNew dataset size: {len(train_df)}")

    source_str = (train_df['Source'].value_counts()
                  .rename_axis('Source')
                  .reset_index(name='count')
                  .to_string(index=False))
    source_str = "\t\t" + source_str.replace("\n","\n\t\t")
    log_update(f"\tSources:\n{source_str}")
    return train_df

def make_train_and_test_df(train_df, test_df):
    """
    Combine the training set and the CAID-2 Disorder-NOX test set into one splits dataframe.

    Every row whose sequence is duplicated anywhere in the combined data is kept only
    if it belongs to the test split. Test rows take their labels from the
    'CAID-2_Disorder_NOX Label' column, and the per-source IDs are consolidated into
    one comma-separated 'IDs' column ordered like 'Source'.

    Args:
        train_df (pd.DataFrame): output of ``make_train_df``.
        test_df (pd.DataFrame): processed CAID-2 Disorder-NOX data (ID, Sequence, Label, Split).
            Not modified.

    Returns:
        pd.DataFrame with columns Sequence, IDs, Split, Source and Label.

    Raises:
        AssertionError: if the label/source bookkeeping invariants checked below fail.
    """
    log_update("\nMaking final dataframe with train and test splits")
    # assign() returns a copy, so the caller's test_df is not modified
    test_df = test_df.assign(Source="CAID-2_Disorder_NOX")
    splits_df = pd.concat([train_df.drop(columns=['No Label Conflicts']),
                           test_df.rename(columns={'ID':'CAID-2_Disorder_NOX ID', 'Label': 'CAID-2_Disorder_NOX Label'})],
                          ignore_index=True)
    log_update(f"\tTrain dataset size: {len(train_df)}\n\tTest dataset size: {len(test_df)}\n\tinitial combined dataset size: {len(splits_df)}")

    # Check for duplicates - if we find any, REMOVE them from train and keep them in test
    duplicates = splits_df[splits_df.duplicated('Sequence')]['Sequence'].unique().tolist()
    n_rows_with_duplicates = len(splits_df[splits_df['Sequence'].isin(duplicates)])
    log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows")
    for i, dup in enumerate(duplicates):
        dup_rows = splits_df.loc[splits_df['Sequence'] == dup]
        train_rows = dup_rows.loc[dup_rows['Split'] == 'Train']
        test_rows = dup_rows.loc[dup_rows['Split'] == 'Test']
        # Join all matches rather than .item(), so an unexpected repeat cannot crash the logging
        fldpnn_id = ",".join(str(v) for v in train_rows['flDPnn ID'])
        idp_crf_id = ",".join(str(v) for v in train_rows['IDP-CRF ID'])
        caid2_disorder_nox_id = ",".join(str(v) for v in test_rows['CAID-2_Disorder_NOX ID'])
        log_update(f"\t\t\t{i+1}: flDPnn ID: {fldpnn_id}\tIDP-CRF ID: {idp_crf_id}\tCAID-2_Disorder_NOX ID: {caid2_disorder_nox_id}\n\t\t\t\tSequence: {dup}")

    # Remove from train and keep in test
    splits_df = splits_df.loc[
        (~splits_df['Sequence'].isin(duplicates)) |       # Either the sequence is NOT duplicated, or
        ((splits_df['Sequence'].isin(duplicates)) & (splits_df['Split']=='Test'))     # it is duplicated and in the test set
    ].reset_index(drop=True)
    # rename_axis + reset_index(name=...) yields the same two columns on pandas 1.x and 2.x
    split_str = (splits_df['Split'].value_counts()
                 .rename_axis('Split')
                 .reset_index(name='count')
                 .to_string(index=False))
    split_str = "\t\t" + split_str.replace("\n","\n\t\t")
    log_update(f"\tRemoved duplicate sequences from training split, kept in test split\n\t\tNew dataset size: {len(splits_df)}\n\n{split_str}")

    # Every train row already has a Label; the unlabeled rows must be exactly the test rows
    assert set(splits_df.loc[splits_df["Label"].isna(), "Split"]) == {"Test"}
    is_test = splits_df["Split"] == "Test"
    splits_df.loc[is_test, "Label"] = splits_df.loc[is_test, "CAID-2_Disorder_NOX Label"]
    splits_df = splits_df.drop(columns=["CAID-2_Disorder_NOX Label"])
    # Make sure there are no NaNs left in Label
    assert splits_df["Label"].notna().all()

    # Print out distribution of sources
    source_counts = splits_df['Source'].value_counts()
    source_str = (source_counts.rename_axis('Source')
                  .reset_index(name='count')
                  .to_string(index=False))
    source_str = "\t\t" + source_str.replace("\n","\n\t\t")
    total_sources = int(source_counts.sum())
    assert total_sources == len(splits_df)
    log_update(f"\n\tSource distribution:\n{source_str}\n\n\t\t\t\t\t\tSum:  {total_sources}")

    # Print largest and smallest seq len in each set
    seq_lens = splits_df['Sequence'].str.len()
    is_train = splits_df['Split'] == 'Train'
    is_test = splits_df['Split'] == 'Test'
    longest_train, shortest_train = seq_lens[is_train].max(), seq_lens[is_train].min()
    longest_test, shortest_test = seq_lens[is_test].max(), seq_lens[is_test].min()
    log_update(f"\n\tLength distributions...\n\t\tTrain: max={longest_train}\tmin={shortest_train}\n\t\tTest: max={longest_test}\tmin={shortest_test}")

    # Consolidate the per-source IDs into one comma-separated column
    splits_df["IDs"] = splits_df.apply(get_unique_ids, axis=1)
    assert splits_df["IDs"].notna().all()
    n_different_ids = int(splits_df["IDs"].str.contains(",").sum())
    log_update(f"\n\tProvided comma-separated IDs in same listed order as Source\n\t\t- train: IDP-CRF first, flDPnn second ({n_different_ids} seqs have multiple distinct IDs)\n\t\t- test: CAID-2_Disorder_NOX")

    # Keep only desired columns
    splits_df = splits_df[[
        'Sequence','IDs','Split','Source','Label'
    ]]

    return splits_df

def main():
    """Run the CAID-2 data-cleaning pipeline and build the fusion benchmark set.

    Steps (all paths are relative to the current working directory):

    1. Process the raw CAID-2 Disorder_NOX test FASTA, the flDPnn training and
       validation files, and the IDP-CRF training file, writing processed CSVs
       to ``processed_data/``.
    2. Merge them (``make_train_df`` / ``make_train_and_test_df``) and write
       ``splits/train_df.csv`` and ``splits/test_df.csv``.
    3. Parse the CAID-2 competition results (``parse_all_caid2_results``).
    4. Scrape FusionPDB level 2/3 entries and process the downloaded structures.
    5. Build ``splits/fusion_bench_df.csv``: FusionPDB level 2/3 sequences that
       have a non-null ``Fusion_pLDDT``, appear in the FusOn-pLM test set
       (``../../data/splits/test_df.csv``) and are absent from the CAID train
       split. Each residue is labelled ``'1'`` (disordered) when its pLDDT is
       below 68.8, else ``'0'``.
    6. Append the benchmark rows to the splits and write ``splits/splits.csv``.

    Progress messages go through ``log_update`` inside
    ``open_logfile("data_cleaning_log.txt")``.
    """
    # Residues with pLDDT strictly below this value are labelled disordered ('1').
    # 68.8 is the cutoff cited as the published AlphaFold-pLDDT threshold.
    plddt_disorder_threshold = 68.8

    def plddts_to_label(plddts_str):
        """Turn a comma-separated pLDDT string (e.g. '90.1,50.2') into a '0'/'1' string (e.g. '01')."""
        return ''.join(
            '1' if float(p) < plddt_disorder_threshold else '0'
            for p in plddts_str.split(",")
        )

    with open_logfile("data_cleaning_log.txt"):
        rawdata_train_test_path = "raw_data/caid2_train_and_test_data"
        # make directories to save processed data and splits
        processeddata_path = "processed_data"
        splits_path = "splits"
        os.makedirs(processeddata_path, exist_ok=True)
        os.makedirs(splits_path, exist_ok=True)

        # Process CAID-2_Disorder_NOX_Testing_Sequences dataset from fasta file
        caid_path = f"{rawdata_train_test_path}/CAID-2_Disorder_NOX_Testing_Sequences.fasta"
        caid_df = process_caid2_disorder_nox_test(caid_path)
        caid_df.to_csv(f"{processeddata_path}/CAID-2_Disorder_NOX_Processed.csv", index=False)

        # Process fldpnn Training and Validation Datasets
        fldpnn_train_path = f"{rawdata_train_test_path}/flDPnn_Training_Dataset.txt"
        fldpnn_val_path = f"{rawdata_train_test_path}/flDPnn_Validation_Annotation.txt"
        fldpnn_train_df = process_fldpnn(fldpnn_train_path, split="training")
        fldpnn_val_df = process_fldpnn(fldpnn_val_path, split="validation")
        fldpnn_train_df.to_csv(f"{processeddata_path}/flDPnn_Training_Dataset.csv", index=False)
        fldpnn_val_df.to_csv(f"{processeddata_path}/flDPnn_Validation_Dataset.csv", index=False)
        # Combine train and val
        fldpnn_df = combine_fldpnn_train_val(fldpnn_train_df, fldpnn_val_df)

        # Process IDP-CRF_Training_Dataset
        idp_crf_train_path = f"{rawdata_train_test_path}/IDP-CRF_Training_Dataset.txt"
        idp_crf_df = process_idp_crf_train(idp_crf_train_path)
        idp_crf_df.to_csv(f"{processeddata_path}/IDP-CRF_Training_Dataset.csv", index=False)

        # Merge
        train_df = make_train_df(fldpnn_df, idp_crf_df)

        # Make a full splits file
        splits_df = make_train_and_test_df(train_df, caid_df)
        final_train_df = splits_df.loc[splits_df['Split'] == 'Train'].reset_index(drop=True)
        final_test_df = splits_df.loc[splits_df['Split'] == 'Test'].reset_index(drop=True)

        # Save final files
        final_train_df.to_csv(f"{splits_path}/train_df.csv", index=False)
        final_test_df.to_csv(f"{splits_path}/test_df.csv", index=False)

        # Process the caid competition results and save them in a more accessible format
        processed_caid2_df = pd.read_csv(f"{processeddata_path}/CAID-2_Disorder_NOX_Processed.csv")
        parse_all_caid2_results(processed_caid2_df)

        # Process data for visualizing fusion oncoproteins
        # Scrape FusionPDB
        scrape_fusionpdb_level_2_3()
        # Process the structures that we downloaded from scraping
        process_fusions_and_hts()

        # Now, figure out which structures are in the test set and isolate those for benchmarking in splits/fusion_bench_df.csv
        fusion_test_df = pd.read_csv("../../data/splits/test_df.csv")
        # columns are sequence, member length, snp_probabilities
        fusion_test_set = set(fusion_test_df['sequence'].tolist())
        log_update(f"\nFinding level 2 and 3 fusion structures that are in the FusOn-pLM test set...\n\tTest set size: {len(fusion_test_set)} seqs")
        level_2_3_info = pd.read_csv(f"{processeddata_path}/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv")
        # Keep only rows that have a structure. The mask is computed on the same
        # frame it indexes so that row labels stay aligned.
        level_2_3_structs = level_2_3_info.loc[level_2_3_info['Fusion_pLDDT'].notna()]
        # The same sequence can appear in several rows; the set collapses duplicates.
        level_2_3_seqs = set(level_2_3_structs['Fusion_Seq'].tolist())
        # if it has a structure, it's in the test set, and it's not in the caid train set, we can benchmark with it
        test_benchmark_seqs = fusion_test_set.intersection(level_2_3_seqs)
        log_update(f"\tTotal fusion proteins in the FusOn-pLM test set: {len(test_benchmark_seqs)}")
        caid_train_set = set(pd.read_csv(f"{splits_path}/train_df.csv")['Sequence'].tolist())
        test_benchmark_seqs = test_benchmark_seqs.difference(caid_train_set)    # subtract off the caid train set to be sure
        log_update(f"\tTotal fusion proteins in the FusOn-pLM test set and NOT in the CAID train set: {len(test_benchmark_seqs)}")

        # Finally, make a dataframe structured like train_df and test_df. Columns are: Sequence,IDs,Split,Source,Label
        # Sorted so the row order (and therefore the output CSVs) does not depend
        # on the per-process string hash seed that governs set iteration order.
        test_benchmark_df = pd.DataFrame(data={'Sequence': sorted(test_benchmark_seqs)})
        # Look up IDs and per-residue pLDDTs only among structure-bearing rows, so a
        # duplicated sequence never resolves to a row without pLDDTs. For duplicates,
        # the last such row wins (dict construction semantics).
        # Let's make the IDs FusionGID
        seq_id_dict = dict(zip(level_2_3_structs['Fusion_Seq'], level_2_3_structs['FusionGID']))
        seq_plddts_dict = dict(zip(level_2_3_structs['Fusion_Seq'], level_2_3_structs['Fusion_AA_pLDDTs']))
        test_benchmark_df['IDs'] = test_benchmark_df['Sequence'].map(seq_id_dict)
        test_benchmark_df['Split'] = 'Fusion_Benchmark'
        test_benchmark_df['Source'] = 'FusionPDB_AlphaFold2'
        # per-residue pLDDTs -> '0'/'1' disorder string, e.g. '90.0,40.0' -> '01'
        test_benchmark_df['Label'] = test_benchmark_df['Sequence'].map(seq_plddts_dict).apply(plddts_to_label)

        # check lengths: one label character per residue
        seq_lens = test_benchmark_df['Sequence'].apply(len)
        label_lens = test_benchmark_df['Label'].apply(len)
        log_update(f"\tAll seq lengths and label lengths match: {(seq_lens == label_lens).all()}")

        # Build a truncated preview. Copy the head so truncating columns neither
        # touches test_benchmark_df nor triggers SettingWithCopyWarning.
        preview_df = test_benchmark_df.head(10).copy()
        preview_df['Sequence'] = preview_df['Sequence'].apply(lambda x: x[0:10] + '...')
        preview_df['Label'] = preview_df['Label'].apply(lambda x: x[0:10] + '...')
        test_benchmark_df_str = "\t" + preview_df.to_string(index=False).replace("\n", "\n\t")
        log_update(f"\nPreview of benchmarking set:\n{test_benchmark_df_str}")
        test_benchmark_df.to_csv(f"{splits_path}/fusion_bench_df.csv", index=False)

        # Add the benchmarking sequences to split
        log_update(f"\nAdding benchmarking sequences to splits_df.csv:\n\tLength before adding bench seqs: {len(splits_df)}")
        splits_df = pd.concat([splits_df, test_benchmark_df], ignore_index=True)
        log_update(f"\tLength after adding bench seqs: {len(splits_df)}")
        # rename_axis + reset_index(name=...) yields columns ['Split', 'count'] on both
        # pandas 1.x and 2.x (value_counts' output naming changed in 2.0).
        split_counts = splits_df['Split'].value_counts().rename_axis('Split').reset_index(name='count')
        split_str = "\t" + split_counts.to_string(index=False).replace("\n", "\n\t")
        log_update(f"Distribution among splits:\n{split_str}")
        splits_df.to_csv(f"{splits_path}/splits.csv", index=False)

if __name__ == "__main__":
    # Script entry point: run the whole cleaning/splitting pipeline only when
    # this file is executed directly, not when it is imported.
    main()