# Process fusion structures and the structures of their head and tail proteins for pLDDTs 

import os
import re

import requests
import pandas as pd
import numpy as np

from Bio.PDB import MMCIFParser
import Bio.PDB as PDB
from bs4 import BeautifulSoup

from fuson_plm.utils.logging import log_update, open_logfile

# Define the AlphaFoldStructure class
class AlphaFoldStructure:
    '''
    This class processes an mmCIF file, either provided locally or downloaded from the AlphaFold database,
    to extract the sequence, per-residue pLDDTs, and secondary structure annotations.
    '''
    def __init__(self, fold_path=None, uniprot_to_download=None, uniprot_output_dir=None, secondary_structure_types=None):
        # If the user provided a PDB path, convert their file to mmCIF. Isolate the suffix
        if fold_path is not None:
          fold_fname = fold_path.split('/')[-1]
          prefix, suffix = fold_fname.rsplit('.', 1)    # split on the last '.' so dotted filenames are handled

          if suffix == 'pdb': # convert to cif
            # make a directory for converted cif files
            conversion_path = 'mmcif_converted_files'
            if not(os.path.exists(conversion_path)):
              os.makedirs(conversion_path)

            fold_path = self.__convert_pdb_to_mmcif(fold_path, f'{conversion_path}/{prefix}.cif')

        self.file_path = fold_path

        # If user provided a uniprot ID to download, download it and save it as the file path so it can be processed
        if uniprot_to_download is not None:
            if fold_path is not None:
              log_update("WARNING: both a fold_path and a uniprot_to_download were provided. Running default: downloading the CIF file for provided UniProt ID.")
            self.file_path = self.__download_mmCIF(uniprot_to_download, output_path=uniprot_output_dir)

        # Either they provide acceptable secondary structure types, or query the internet for them
        if secondary_structure_types is None:
          self.secondary_structure_types = self.__pull_secondary_structure_types()
        else:
          self.secondary_structure_types = secondary_structure_types

        # If there's a CIF file, initialize the object
        if self.file_path:
          self.cif_lines = self.__parse_cif()
          self.secondary_structures = self.__extract_secondary_structures()
          self.structure_dict = self.__calc_pLDDTs()
          self.sequence = self.structure_dict['seq']
          self.plddts = self.structure_dict['res_pLDDTs']
          self.avg_pLDDT = self.structure_dict['avg_pLDDT']
          self.residues_df = self.__create_residues_summary_dataframe()
          self.secondary_structures_df = self.__create_secondary_structures_summary_dataframe()
        # Otherwise, print an error.
        else:
          log_update("ERROR: structure could not be created. No CIF file found.")

    def __convert_pdb_to_mmcif(self, pdb_filename, mmcif_filename):
      parser = PDB.PDBParser()
      structure = parser.get_structure('structure', pdb_filename)

      io = PDB.MMCIFIO()
      io.set_structure(structure)
      io.save(mmcif_filename)
      return mmcif_filename

    def __download_mmCIF(self, uniprot_id, output_path=None):
        '''
        Download mmCIF file with provided uniprot_id and optional output_path for the downloaded file.

        Return: path to downloaded file if successful, None otherwise
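
        Example (illustrative): uniprot_id='P04637' (human p53) fetches
        https://alphafold.ebi.ac.uk/files/AF-P04637-F1-model_v4.cif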
        '''
        full_file_name = f"AF-{uniprot_id}-F1-model_v4.cif"     # define file name that will be found on the AlphaFold2 database.
        # if output path not provided, just save locally under full_file_name
        if output_path is None:
            output_path = full_file_name
        else:
            output_path = f"{output_path}/{full_file_name}"

        # request the URL for the file
        url = f"https://alphafold.ebi.ac.uk/files/{full_file_name}"
        response = requests.get(url)

        if response.status_code == 200:
            with open(output_path, 'wb') as file:
                file.write(response.content)
            #log_update(f"File downloaded successfully and saved as {output_path}")
        else:
            log_update(f"Failed to download file. Status code: {response.status_code}")
            return None

        return output_path

    def __pull_secondary_structure_types(self):
        '''
        Pull a dictionary of secondary structure types and their descriptions from the PDB mmCIF website
        (necessary for annotating the CIF file). Only called if the user does not provide such a dictionary
        themselves. Delegates to the module-level pull_secondary_structure_types() below.
        '''
        return pull_secondary_structure_types()

    def get_secondary_structure_types(self):
      '''
      Display secondary structure types
      '''
      log_update("Secondary Structure Types in mmCIF files:")
      for ss_type, description in self.secondary_structure_types.items():
          log_update(f"{ss_type}: {description}")

      return self.secondary_structure_types

    def __parse_cif(self):
        '''
        Read cif file lines from self.file_path
        '''
        with open(self.file_path, 'r') as file:
            lines = file.readlines()
        return lines

    def __extract_secondary_structures(self):
        '''
        Iterate through the lines of the cif file to find each secondary structure.
        Returns a list with one tuple per amino acid that has a secondary structure annotation. Each tuple contains:
          1. Structure Type (e.g. STRN)
          2. Structure ID (e.g. STRN1)
          3. Description (e.g. beta strand)
          4. Position (e.g. 3)
        '''
        secondary_structures = []
        parsing_secondary_structure = False

        # iterate through cif lines
        for line in self.cif_lines:
            # hone in on the right section of the cif file
            if line.startswith("_struct_conf.conf_type_id"):
                parsing_secondary_structure = True
                continue
            # if we're in the right section...
            if parsing_secondary_structure:
                if line.startswith("#"):
                    parsing_secondary_structure = False   # no longer in the right section
                    continue
                # still in the right section
                columns = line.split()
                # iterate through columns to find each piece of info we need
                if len(columns) >= 14:    # need at least 14 columns to safely read columns[13] below
                    sec_struc_type = columns[6]
                    sec_struc_id = columns[13]
                    start_res = int(columns[2])
                    end_res = int(columns[9])
                    sec_struc_name = self.secondary_structure_types.get(sec_struc_type, 'Unknown')
                    # make tuple for this position in the sequence
                    for pos in range(start_res, end_res + 1):
                        secondary_structures.append((sec_struc_type, sec_struc_id, sec_struc_name, pos))

        return secondary_structures

    def __calc_pLDDTs(self):
        '''
        Iterate through the cif file to return a dictionary with a few key pieces of info:
          1. Sequence
          2. pLDDTs for each residue (AlphaFold stores the per-residue pLDDT in the B-factor field)
          3. Average pLDDT
        '''

        # define dictionary needed to translate into single-letter AA code
        aa_dict = {
            "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F",
            "GLY": "G", "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L",
            "MET": "M", "ASN": "N", "PRO": "P", "GLN": "Q", "ARG": "R",
            "SER": "S", "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y"
        }

        parser = MMCIFParser(QUIET=True)    # create a parser
        data = parser.get_structure("structure", self.file_path)    # parse structure

        # count models and chains (should be 1 model and 1 chain; don't use this class to parse a complex)
        model = data.get_models()
        models = list(model)
        chains = list(models[0].get_chains())

        # iterate through the chains and get amino acid letters and pLDDTs
        all_pLDDTs = []
        seq = ''    # accumulate the sequence across chains so multi-chain files are not silently truncated
        for n in range(len(chains)):
            residues = list(chains[n].get_residues())   # extract all residues
            pLDDTs = [0] * len(residues)    # initialize empty pLDDT array for this chain

            # iterate through all residues in this chain
            for i in range(len(residues)):
                r = residues[i]
                # which amino acid is here?
                try:
                    seq += aa_dict[r.get_resname()]
                # error if it's not a real amino acid
                except KeyError:
                    log_update('residue name invalid')
                    break

                # look at each atom. Get its pLDDT (bfactor); all atoms within one residue should share the same value.
                atoms = list(r.get_atoms())
                bfactor = atoms[0].get_bfactor()
                for a in range(len(atoms)):
                    # if not all atoms within an AA have the same pLDDT, warn and stop checking this residue.
                    if atoms[a].get_bfactor() != bfactor:
                        log_update(f'WARNING: atoms in residue {i+1} have differing B-factors (pLDDTs)')
                        break

                pLDDTs[i] = bfactor   # add pLDDT for this residue to the list.

            all_pLDDTs.extend(pLDDTs) # add pLDDTs for this chain to list of all pLDDTs

        avg_pLDDT = np.mean(all_pLDDTs) # average pLDDTs across all chains
        return_dict = {
            'avg_pLDDT': round(avg_pLDDT, 2),
            'res_pLDDTs': all_pLDDTs,
            'seq': seq
        }
        return return_dict

    def __create_residues_summary_dataframe(self):
        '''
        Create a dataframe that summarizes the secondary structure information for each residue.
        Columns:
          1. Position: amino acid position (e.g. 3)
          2. Residue: amino acid 1-letter code (e.g. A)
          3. pLDDT: AlphaFold's pLDDT score for this residue to 2 decimal places (e.g. 77.54)
          4. Structure Type: type of secondary structure (e.g. STRN)
          5. Structure ID: ID of this secondary structure (e.g. STRN1)
          6. Description: description of this secondary structure (e.g. beta strand)
          7. Disordered: is this residue disordered or not? A residue is not disordered if it is in a HELX or STRN. (True/False)
        '''
        # Convert the secondary structures to a dataframe
        df_secondary_structures = pd.DataFrame(self.secondary_structures, columns=['Structure Type', 'Structure ID', 'Description', 'Position'])

        # Add Residue and pLDDT columns to the dataframe
        df_temp = pd.DataFrame(
            data={
                'Position': list(range(1, len(self.sequence) + 1)),
                'Residue': list(self.sequence),
                'pLDDT': self.plddts
            })

        df_secondary_structures = pd.merge(df_secondary_structures, df_temp, on='Position', how='right')
        # Determine if each residue is disordered or not based on what Structure Type it's in. If helix or strand, it's ordered. If anything else or NaN, it's disordered.
        df_secondary_structures['Disordered'] = df_secondary_structures['Structure Type'].apply(
            lambda x: False if (type(x)==str and (('HELX' in x) or ('STRN' in x))) else True
        )

        return df_secondary_structures

    def __create_secondary_structures_summary_dataframe(self):
        '''
        Create a dataframe grouped by each Structure ID, providing a summary of each secondary structure in the chain.
        Columns:
          1. Structure ID: ID of this secondary structure (e.g. STRN1)
          2. Start: start position of this secondary structure (e.g. 3)
          3. End: end position of this secondary structure (e.g. 12)
          4. Start Residue: amino acid 1-letter code of the start position (e.g. A)
          5. End Residue: amino acid 1-letter code of the end position (e.g. L)
          6. Disordered: is this secondary structure disordered or not? A structure is not disordered if it is a HELX or STRN. (True/False)
          7. Description: description of this secondary structure (e.g. beta strand)
          8. Structure Type: type of secondary structure (e.g. STRN)
          9. avg_pLDDT: average pLDDT for this secondary structure (e.g. 77.54)
        '''

        # Apply groupby on self.residues_df to reorganize it by Structure ID
        secondary_structures_df = self.residues_df.groupby('Structure ID').agg({
            'Position': ['first', 'last'],
            'Residue': ['first','last'],
            'Disordered': 'first',
            'Description': 'first',
            'Structure Type': 'first',
            'pLDDT': 'mean'
        }).reset_index()

        # Flatten the multi-level columns
        secondary_structures_df.columns = ['Structure ID', 'Start', 'End', 'Start Residue', 'End Residue', 'Disordered', 'Description', 'Structure Type', 'avg_pLDDT']
        secondary_structures_df['avg_pLDDT'] = secondary_structures_df['avg_pLDDT'].round(2)

        # Display the summarized DataFrame
        return secondary_structures_df

    def get_residues_df(self):
        return self.residues_df

    def get_secondary_structures_df(self):
        return self.secondary_structures_df

    def get_full_sequence(self):
        return ''.join(self.residues_df['Residue'])

    def get_average_plddt(self):
        # use dropna() so NaN pLDDTs are excluded ('is not None' does not filter out NaN)
        plddt_values = self.residues_df['pLDDT'].dropna()
        return plddt_values.mean() if len(plddt_values) else None
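
# Example usage of AlphaFoldStructure (a minimal sketch; the UniProt ID and output directory are illustrative):
#   struct = AlphaFoldStructure(uniprot_to_download='P04637', uniprot_output_dir='example_structures')
#   print(struct.get_average_plddt())              # average pLDDT across all residues
#   print(struct.get_secondary_structures_df())    # one row per secondary structure element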

def pull_secondary_structure_types():
    '''
    Pull a dictionary of protein secondary structure types (keyed by _struct_conf_type.id) and their
    descriptions from the wwPDB mmCIF dictionary website.
    '''
    url = "https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conf_type.id.html"
    response = requests.get(url)

    if response.status_code != 200:
        raise Exception("Failed to retrieve mmCIF dictionary")

    soup = BeautifulSoup(response.content, 'html.parser')

    # Debug: uncomment to write the prettified soup to a txt file
    #with open('mmcif_dictionary.txt', 'w') as f:
    #    f.write(soup.prettify())

    # Find the h4 header with the class "panel-title" and text "Controlled Vocabulary"
    header = soup.find('h4', class_='panel-title')
    if header is None or 'Controlled Vocabulary' not in header.text:
        raise Exception("Could not find the 'Controlled Vocabulary' header")

    # Debug: Print the found header
    #log_update(f"Found header: {header}")

    # The table should be the next sibling of the header
    table = header.find_next('table')
    if table is None:
        raise Exception("Could not find the table following the 'Controlled Vocabulary' header")

    # Debug: Print the found table (only the opening <table> tag)
    #log_update(f"Found table (showing header line): {str(table).split('<thead')[0]}")

    secondary_structure_types = {}
    rows = table.find_all('tr')
    for row in rows[1:]:  # Skip the header row
        cols = row.find_all('td')
        if len(cols) > 1:
            type_id = cols[0].text.strip()
            description = cols[1].text.replace('\t', ' ').strip()

            # Replace multiple spaces with a single space
            description = re.sub(' +', ' ', description)

            if '(protein)' in description:
                secondary_structure_types[type_id] = description

    return secondary_structure_types
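
# Illustrative shape of the returned dictionary (actual entries come from the live wwPDB page):
#   {'HELX_RH_AL_P': 'right-handed alpha helix (protein)', 'STRN': 'beta strand (protein)', ...}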

# Process structures downloaded from FusionPDB 
def process_fusionpdb_fusion_files(files, level_2_3_structure_info, folder, save_path=None):
    '''
    Process fusion structure files downloaded from FusionPDB, filling sequence and pLDDT columns in
    level_2_3_structure_info. save_path must be provided: each processed structure is checkpointed to it,
    and the completed dataframe is reloaded from it at the end.
    '''
    # get secondary structure types so we can process PDBs
    secondary_structure_types = pull_secondary_structure_types()

    # Initialize 3 columns to store structural info - the AA seq in the fold (should match), the Avg pLDDT, and the per-residue pLDDTs (comma-separated, 2 decimal pts.)
    level_2_3_structure_info['Fold AA seq'] = ['']*len(level_2_3_structure_info)
    level_2_3_structure_info['Avg pLDDT'] = [0]*len(level_2_3_structure_info)
    level_2_3_structure_info['pLDDTs'] = ['']*len(level_2_3_structure_info)
    
    # pre-loop processed
    pre_loop_processed = []
    if save_path is not None and os.path.exists(save_path):    # guard against the None default
        pre_loop_processed = pd.read_csv(save_path)
        pre_loop_processed = pre_loop_processed['Structure Link'].tolist()
        pre_loop_processed = [x.split('/')[-1] for x in pre_loop_processed]
        log_update(f"Total structures already processed: {len(pre_loop_processed)}")
    
    log_update("\nProcessing fusion structures...")
    # only process structures we haven't processed yet
    for i, structure in enumerate(files):
        log_update(f'\tProcessing #{i+1}: {structure}')
        
        # make sure we haven't already processed it and aren't wasting time
        if structure in pre_loop_processed:
            log_update(f"\t\tAlready processed. Continuing...")
            continue
        
        # create AlphaFoldStructure object
        obj = AlphaFoldStructure(fold_path=f'{folder}/{structure}', secondary_structure_types=secondary_structure_types)
        aa_seq = obj.get_full_sequence()
        avg_plddt = obj.get_average_plddt()
        residues_df = obj.get_residues_df()
        all_plddts = ",".join(residues_df['pLDDT'].astype(str).tolist())
        
        log_update(f"\t\tAvg pLDDT: {round(avg_plddt,2)}\tFold AA seq: {aa_seq}\tFirst 5 pLDDTs: {','.join(all_plddts.split(',')[0:5])}")

        mask = level_2_3_structure_info['Structure Link'].str.contains(f"/{structure}", regex=False)    # regex=False: filenames contain '.'
        level_2_3_structure_info.loc[mask, 'Fold AA seq'] = aa_seq
        level_2_3_structure_info.loc[mask, 'Avg pLDDT'] = avg_plddt
        level_2_3_structure_info.loc[mask, 'pLDDTs'] = all_plddts
        
        # write the newly processed row(s) of level_2_3_structure_info to csv
        cur_df = level_2_3_structure_info.loc[mask].reset_index(drop=True)
        if os.path.exists(save_path):
            cur_df.to_csv(save_path,mode='a',header=False,index=False)
        else:
            cur_df.to_csv(save_path,index=False)
            
    # now reload the completed dataframe
    level_2_3_structure_info = pd.read_csv(save_path)
    return level_2_3_structure_info
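
# Example call (a sketch; paths mirror those used in process_fusions_and_hts below, and structure_links_df
# stands in for the dataframe loaded from giant_level2-3_fusion_protein_structure_links.csv):
#   files = os.listdir('raw_data/fusionpdb/structures')
#   info = process_fusionpdb_fusion_files(files, structure_links_df, 'raw_data/fusionpdb/structures',
#                                         save_path='processed_data/fusionpdb/intermediates/giant_level2-3_fusion_protein_structures_processed.csv')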

def process_fusionpdb_head_tail_files(ht, save_path='heads_and_tails_structures_processed.csv'):
    '''
    Process the structures of head and tail proteins. ht is a list of UniProt IDs; each structure is
    downloaded from the AlphaFold database and results are checkpointed to save_path.
    '''
    log_update("\nProcessing head and tail structures...")
    
    # get secondary structure types so we can process PDBs
    secondary_structure_types = pull_secondary_structure_types()
    
    # make directory to save alphafold DB structures of heads and tails
    os.makedirs('raw_data/fusionpdb/head_tail_af2db_structures',exist_ok=True)
    
    # pre-loop processed
    pre_loop_processed = []
    if os.path.exists(save_path):
        pre_loop_processed = pd.read_csv(save_path)
        pre_loop_processed = pre_loop_processed['UniProtID'].tolist()
        log_update(f"Heads and tails already processed: {len(pre_loop_processed)}")
        
    ht_structures_df = pd.DataFrame(
        data = {
            'UniProtID': ['']*len(ht),
            'Avg pLDDT': ['']*len(ht),
            'All pLDDTs': ['']*len(ht),
            'Seq': ['']*len(ht)
        }
    )
    
    for i, uniprotid in enumerate(ht):
        log_update(f'\tProcessing #{i+1}: {uniprotid}')
        aa_seq, avg_plddt, all_plddts = None, None, None
        
        # make sure we haven't processed it yet!
        if uniprotid in pre_loop_processed:
            log_update(f"\t\tAlready processed. Continuing")
            continue

        try:
            obj = AlphaFoldStructure(uniprot_to_download=uniprotid, secondary_structure_types=secondary_structure_types,
                                     uniprot_output_dir='raw_data/fusionpdb/head_tail_af2db_structures')
            aa_seq = obj.get_full_sequence()
            avg_plddt = obj.get_average_plddt()
            residues_df = obj.get_residues_df()
            all_plddts = ",".join(residues_df['pLDDT'].astype(str).tolist())

            log_update(f"\t\tAvg pLDDT: {round(avg_plddt,2)}\tFold AA seq: {aa_seq}\tFirst 5 pLDDTs: {','.join(all_plddts.split(',')[0:5])}")

        except Exception as e:
            log_update(f"\t\tCould not process {uniprotid}: {e}. Avg pLDDT: {None}\tFold AA seq: {None}\tFirst 5 pLDDTs: {None}")
            
        # Fill in info for combined ht df
        ht_structures_df.loc[i, 'UniProtID'] = uniprotid
        ht_structures_df.loc[i, 'Avg pLDDT'] = avg_plddt
        ht_structures_df.loc[i, 'All pLDDTs'] = all_plddts
        ht_structures_df.loc[i, 'Seq'] = aa_seq
        
        # write this head/tail protein's row to csv
        cur_df = pd.DataFrame(ht_structures_df.iloc[i,:]).T.reset_index(drop=True)
        if os.path.exists(save_path):
            cur_df.to_csv(save_path,mode='a',header=False,index=False)
        else:
            cur_df.to_csv(save_path,index=False)
    
    # ensure we got everything
    ht_structures_df = pd.read_csv(save_path)
    level_2_3 = pd.read_csv('processed_data/fusionpdb/intermediates/giant_level2-3_fusion_protein_head_tail_info.csv')
    level_2_3['FusionGene'] = level_2_3['FusionGene'].str.replace('-','::')
    heads = level_2_3['HGUniProtAcc'].tolist()
    tails = level_2_3['TGUniProtAcc'].tolist()
    ht = heads + tails
    ht = set([x for x in ht if isinstance(x, str)])
    ht = set(','.join(ht).split(','))   # some entries are comma-separated accession lists; split them apart

    log_update(f"total heads and tails: {len(ht)}")
    log_update(f"total processed: {len(ht_structures_df)}\t{len(ht_structures_df['UniProtID'].unique())}")

    # which ones are missing?
    missing = set(ht) - set(ht_structures_df['UniProtID'].unique())
    log_update(f"missing: {len(missing)}")
    log_update(missing)
    
    # Some heads and tails are not in the AlphaFold database. I folded these myself, externally.
    ht_structures_df = ht_structures_df.replace('',np.nan)
    need_to_fold = ht_structures_df[ht_structures_df['Avg pLDDT'].isna()]['UniProtID'].tolist()
    with open('processed_data/fusionpdb/intermediates/uniprotids_not_in_afdb.txt','w') as f:
        for uniprotid in need_to_fold:
            f.write(f'{uniprotid}\n')

    idmap = pd.read_csv('raw_data/fusionpdb/not_in_afdb_idmap.txt',sep='\t')
    idmap = idmap[idmap['Entry'].isin(need_to_fold)].reset_index(drop=True)
    idmap = idmap[['Entry','Sequence']].rename(columns={'Entry': 'ID'})
    idmap['Length'] = idmap['Sequence'].apply(len)

    log_update("Investigating heads and tails that were not in the AF2 database:")
    log_update(f"\tMin length: {min(idmap['Length'])}")
    log_update(f"\tMax length: {max(idmap['Length'])}")
    idmap = idmap.sort_values(by='Length',ascending=True).reset_index(drop=True)
    
    # Manually folded structures: average pLDDTs computed externally for the IDs missing from the AlphaFold database
    manual_avg_plddts = {'Q9NNW7': 91.68, 'Q16881': 89.55, 'Q86V15': 48.14}
    for id, avg_plddt in manual_avg_plddts.items():
        if id in idmap['ID'].tolist():
            ht_structures_df.loc[ht_structures_df['UniProtID']==id, 'Avg pLDDT'] = avg_plddt
            ht_structures_df.loc[ht_structures_df['UniProtID']==id, 'Seq'] = idmap.loc[idmap['ID']==id, 'Sequence'].item()

    return ht_structures_df
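
# The returned dataframe has one row per head/tail UniProt ID with columns:
#   UniProtID, Avg pLDDT, All pLDDTs (comma-separated per-residue values), Seq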

def process_fusions_and_hts():
    # Process the structures of fusion proteins downloaded from FusionPDB
    level_2_3_structure_info_og = pd.read_csv('processed_data/fusionpdb/intermediates/giant_level2-3_fusion_protein_structure_links.csv')

    # figure out which ones we have
    folder = 'raw_data/fusionpdb/structures'
    # get all the structure files in folder
    files = os.listdir(folder)
    log_update(f"total pdbs: {len(files)}")
    log_update(f"examples: {files[:5]}")
    
    os.makedirs('processed_data/fusionpdb', exist_ok=True)  
    
    # process the full fusion pdbs
    level_2_3_structure_info = process_fusionpdb_fusion_files(files, level_2_3_structure_info_og, folder, save_path='processed_data/fusionpdb/intermediates/giant_level2-3_fusion_protein_structures_processed.csv')  
    
    # process the head and tail pdbs
    level_2_3 = pd.read_csv('processed_data/fusionpdb/intermediates/giant_level2-3_fusion_protein_head_tail_info.csv')
    level_2_3['FusionGene'] = level_2_3['FusionGene'].str.replace('-','::')
    # Get the heads and tails, see how many unique proteins we're working with
    heads = level_2_3['HGUniProtAcc'].tolist()
    tails = level_2_3['TGUniProtAcc'].tolist()
    ht = heads + tails
    ht = set([x for x in ht if isinstance(x, str)])
    ht = set(','.join(ht).split(','))   # some entries are comma-separated accession lists; split them apart
    log_update(f"Unique heads/tails: {len(ht)}")
    
    heads_tails_analyzed = process_fusionpdb_head_tail_files(list(ht), save_path='processed_data/fusionpdb/heads_tails_structural_data.csv') 
    
    # In the level_2_3 database, we only have the fusions with documented heads and tails.
    level_2 = pd.read_csv('raw_data/fusionpdb/FusionPDB_level2_curated_09_05_2024.csv')
    level_3 = pd.read_csv('raw_data/fusionpdb/FusionPDB_level3_curated_09_05_2024.csv')
    joined_23 = pd.concat([level_2,level_3]).reset_index(drop=True)
    joined_23['FusionGene'] = joined_23['FusionGene'].str.replace('-','::')     # use new notation with head::tail
    log_update(f"\nnumber of duplicated fusion gene rows: {len(joined_23[joined_23['FusionGene'].duplicated()])}")
    # make the dictionary
    fo_gid_dict = dict(zip(joined_23['FusionGene'],joined_23['FusionGID']))
    log_update(len(fo_gid_dict))

    # let's clean giant level 2 and level 3
    # first, drop anything where Fold AA seq is NaN: there is no fold.
    level_2_3_structure_info_clean = level_2_3_structure_info.replace('',np.nan)    # make sure there are nans where there should be
    level_2_3_structure_info_clean = level_2_3_structure_info_clean.dropna(subset=['Fold AA seq']).reset_index(drop=True)
    log_update(f"length of processed structure file: {len(level_2_3_structure_info_clean)}")
    level_2_3_structure_info_clean['pLDDT'] = level_2_3_structure_info_clean['Avg pLDDT'].round(2)
    level_2_3_structure_info_clean = level_2_3_structure_info_clean.drop(columns=['Avg pLDDT'])
    level_2_3_structure_info_clean['FusionGene'] = level_2_3_structure_info_clean['FusionGene'].str.replace('-','::')
    level_2_3_structure_info_clean['FusionGID'] = level_2_3_structure_info_clean['FusionGene'].apply(lambda x: fo_gid_dict[x])
    
    # now let's use the FusionPDB database we processed as ground truth for sequence, rather than the webpage
    log_update("Using FusionPDB as ground truth for sequences...")
    raw_download = pd.read_csv('../../data/raw_data/FusionPDB.txt',sep='\t',header=None)
    raw_download['FusionGene'] = raw_download[7]+ '::' + raw_download[11]
    raw_download = raw_download.rename(columns={18:'Raw Download AA Seq'})
    log_update(f"FusionPDB raw download size: {len(raw_download)}")

    level_2_3_structure_info_clean_ids = set(level_2_3_structure_info_clean['FusionGene'].tolist())
    level_2_3_structure_info_clean_seqs = set(level_2_3_structure_info_clean['Fold AA seq'].tolist())
    raw_download_ids = set(raw_download['FusionGene'].tolist())
    raw_download_seqs = set(raw_download['Raw Download AA Seq'].tolist())
    log_update(f"Number of overlapping gene IDs: {len(level_2_3_structure_info_clean_ids.intersection(raw_download_ids))}")
    log_update(f"Number of overlapping sequences: {len(level_2_3_structure_info_clean_seqs.intersection(raw_download_seqs))}")
    # attempt a merge on Raw Download AA Seq with both of the original sequence columns
    # Merging with the AlphaFold sequence
    test_merge_1 = pd.merge(
        level_2_3_structure_info_clean.rename(columns={'Fold AA seq': 'Raw Download AA Seq'}),
        raw_download,
        on=['FusionGene','Raw Download AA Seq'],
        how='inner'
    )
    test_merge_1 = test_merge_1.drop(columns=['AA seq'])
    test_merge_1['Seq Source'] = 'AlphaFold,Raw Download'
    log_update(f"Merge on AlphaFold AA Seq and raw Download AA Seq. len={len(test_merge_1)}")
    # Merging with the webpage sequence
    test_merge_2 = pd.merge(
        level_2_3_structure_info_clean.rename(columns={'AA seq': 'Raw Download AA Seq'}),
        raw_download,
        on=['FusionGene','Raw Download AA Seq'],
        how='inner'
    )
    test_merge_2 = test_merge_2.drop(columns=['Fold AA seq'])
    test_merge_2['Seq Source'] = 'Webpage,Raw Download'
    log_update(f"Merge on Webpage AA Seq and Raw Download AA Seq. len={len(test_merge_2)}")

    test_merge = pd.concat([test_merge_1,test_merge_2])
    test_merge['Len(AA seq)'] = test_merge['Raw Download AA Seq'].apply(lambda x: len(x))
    # drop duplicates
    test_merge = test_merge.drop_duplicates().reset_index(drop=True) 
    
    # for anything that has a CIF, keep the CIF
    log_update(f"len test_merge before keeping CIFs over identical PDBs: {len(test_merge)}")
    test_merge = test_merge.sort_values(by='Structure Type',ascending=True).reset_index(drop=True).groupby(['Hgene', 'Hchr', 'Hbp', 'Hstrand', 'Tgene', 'Tchr',
        'Tbp', 'Tstrand', 'Len(AA seq)', 'FusionGene',
        'Level', 'Raw Download AA Seq', 'pLDDT', 'pLDDTs','FusionGID', 'Seq Source']).agg(
            {
                'Structure Link': 'first',
                'Structure Type': 'first'
            }
        ).reset_index()
    log_update(f"len after: {len(test_merge)}")
    
    # for anything with multiple seq sources, concatenate them
    log_update(f"len test_merge before combining seq sources: {len(test_merge)}")
    test_merge = test_merge.groupby(['Structure Link','Hgene', 'Hchr', 'Hbp', 'Hstrand', 'Tgene', 'Tchr',
        'Tbp', 'Tstrand', 'Len(AA seq)', 'FusionGene','Structure Type',
        'Level', 'Raw Download AA Seq', 'pLDDT', 'pLDDTs', 'FusionGID', ]).agg(
            {
                'Seq Source': lambda x: ','.join(x)
            }
        ).reset_index()
    test_merge['Seq Source'] = test_merge['Seq Source'].apply(lambda x: ','.join(set(x.split(','))))
    log_update(f"len after: {len(test_merge)}")
    
    # are there cases of multiple folds for the same sequence? miraculously, yes!
    dup_seqs = test_merge[test_merge['Raw Download AA Seq'].duplicated()]['Raw Download AA Seq'].unique().tolist()
    log_update(f"sequences with multiple folds: {len(dup_seqs)}")

    # for anything with multiple folds / seq, randomly choose the first one. 
    log_update(f"len test_merge before randomly choosing first fold when one seq has multiple folds: {len(test_merge)}")
    test_merge = test_merge.groupby(['Hgene', 'Hchr', 'Hbp', 'Hstrand', 'Tgene', 'Tchr',
        'Tbp', 'Tstrand', 'Len(AA seq)', 'FusionGene',
        'Level', 'Raw Download AA Seq', 'FusionGID', ]).agg(
            {
                'Structure Link': 'first',
                'Structure Type': 'first',
                'Seq Source': 'first',
                'pLDDT': 'first',
                'pLDDTs': 'first'
            }
        ).reset_index()
    log_update(f"len after: {len(test_merge)}")
    
    # how many rows DO NOT have the right AlphaFold sequences?
    source_str = test_merge['Seq Source'].value_counts().reset_index().rename(columns={'index': 'Seq Source','Seq Source': 'count'}).to_string(index=False)
    source_str = "\t\t" + source_str.replace("\n","\n\t\t")
    log_update(f"Distribution of sequence sources:\n{source_str}")
    
    # dropping anything where AF sequence is wrong. Don't want to use these.
    test_merge = test_merge.loc[test_merge['Seq Source'].str.contains('AlphaFold')].reset_index(drop=True)
    log_update(f"Dropped rows where AlphaFold sequence was incorrect. New DataFrame length: {len(test_merge)}")
    # make sure there's only one FusionGID number for each sequence for each GID
    assert len(test_merge[test_merge.duplicated(['FusionGID','Raw Download AA Seq'])])==0
    
    # round pLDDTs
    test_merge['pLDDT'] = test_merge['pLDDT'].round(2)
    
    # Finally, select only the columns we want
    test_merge_v2 = test_merge[
    ['FusionGID', 'FusionGene', 'Raw Download AA Seq','Len(AA seq)', 'Hgene', 'Hchr', 'Hbp', 'Hstrand', 'Tgene', 'Tchr', 'Tbp', 'Tstrand',
       'Level','Structure Link', 'Structure Type', 'pLDDT', 'pLDDTs', 'Seq Source']
    ].rename(
        columns={
            'Raw Download AA Seq': 'Fusion_Seq',
            'Seq Source': 'Fusion_Seq_Source',
            'Structure Link': 'Fusion_Structure_Link',
            'Structure Type': 'Fusion_Structure_Type',
            'pLDDT': 'Fusion_pLDDT',
            'pLDDTs': 'Fusion_AA_pLDDTs',
            'Len(AA seq)': 'Fusion_Length'
        }
    )
    log_update(f"Unique FusionGIDs: {len(test_merge_v2['FusionGID'].unique())}")
    log_update(f"Number of structures: {len(test_merge_v2)}")
    
    # Note that test_merge_v2 will still have duplicates where it's same seq, different ID
    log_update("\nChecking for duplicate sequences..")
    log_update(f"\tThe structure-based fusion database of length {len(test_merge_v2)} has {len(test_merge_v2['Fusion_Seq'].unique())} unique fusion sequences.")
    dup_seqs = test_merge_v2[test_merge_v2['Fusion_Seq'].duplicated()]['Fusion_Seq'].tolist()
    dup_seqs_df = test_merge_v2.loc[test_merge_v2['Fusion_Seq'].isin(dup_seqs)].reset_index(drop=True)
    dup_seqs_df['FusionGID'] = dup_seqs_df['FusionGID'].astype(str)
    dup_seqs_df = dup_seqs_df.groupby('Fusion_Seq').agg({
        'FusionGID': lambda x: ','.join(x),
        'FusionGene': lambda x: ','.join(x)
    })
    dup_seqs_df_str = dup_seqs_df.to_string(index=False)
    dup_seqs_df_str = "\t"+dup_seqs_df_str.replace("\n","\n\t")
    log_update(f"\tShowing FUsionGIDs and FusionGenes for duplicated sequences below:\n{dup_seqs_df_str}")
    
    # round pLDDT column
    heads_tails_analyzed['Avg pLDDT'] = heads_tails_analyzed['Avg pLDDT'].round(2)
    # merge treating data as head data
    level_2_3_v2 = pd.merge(
        level_2_3,
        heads_tails_analyzed.rename(columns={'UniProtID': 'HGUniProtAcc', 'Avg pLDDT': 'HG_pLDDT', 'All pLDDTs': 'HG_AA_pLDDTs', 'Seq': 'HG_Seq'}),
        on='HGUniProtAcc',
        how='left'
    )
    # merge treating data as tail data
    level_2_3_v2 = pd.merge(
        level_2_3_v2,
        heads_tails_analyzed.rename(columns={'UniProtID': 'TGUniProtAcc', 'Avg pLDDT': 'TG_pLDDT', 'All pLDDTs': 'TG_AA_pLDDTs', 'Seq': 'TG_Seq'}),
        on='TGUniProtAcc',
        how='left'
    )
    
    # giant_level2_3 with valid structures only, no duplicate sequences, and head and tail proteins' uniprot IDs, pLDDTs, and sequences
    test_merge_v2.to_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv',index=False)
    log_update("Saved file with all fusion structure pLDDTs to: processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv")
    
    # level_2_3 with the head and tail proteins' uniprot IDs, pLDDTs, and sequences
    level_2_3_v2.to_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_FusionGID_info.csv',index=False)
    log_update("Saved file with all fusion protein heads and tails, and their structure pLDDTs to: processed_data/fusionpdb/FusionPDB_level2-3_cleaned_FusionGID_info.csv")
        
def main():
    with open_logfile("process_fusion_structures_log.txt"):
        process_fusions_and_hts()
        
if __name__ == "__main__":
    main()