svincoff committed
Commit 3efa812 · 1 Parent(s): e048d40

mutation prediction discovery and recovery

Files changed (40)
  1. fuson_plm/benchmarking/caid/README.md +2 -2
  2. fuson_plm/benchmarking/idr_prediction/README.md +2 -2
  3. fuson_plm/benchmarking/mutation_prediction/README.md +114 -0
  4. fuson_plm/benchmarking/mutation_prediction/discovery/clean.py +71 -0
  5. fuson_plm/benchmarking/mutation_prediction/discovery/color_discovered_mutations.ipynb +418 -0
  6. fuson_plm/benchmarking/mutation_prediction/discovery/config.py +16 -0
  7. fuson_plm/benchmarking/mutation_prediction/discovery/discover.py +346 -0
  8. fuson_plm/benchmarking/mutation_prediction/discovery/make_color_bar.py +25 -0
  9. fuson_plm/benchmarking/mutation_prediction/discovery/plot.py +167 -0
  10. fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/521_logit_bfactor.cif +0 -0
  11. pytorch_model.bin → fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/domain_conservation_fusions_inputfile.csv +2 -2
  12. fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/test_seqs_tftf_kk.csv +3 -0
  13. fuson_plm/benchmarking/mutation_prediction/discovery/raw_data/salokas_2020_tableS3.csv +3 -0
  14. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/conservation_heatmap.png +0 -0
  15. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv +3 -0
  16. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/predicted_tokens.csv +3 -0
  17. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl +3 -0
  18. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/conservation_heatmap.png +0 -0
  19. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/full_results_with_logits.csv +3 -0
  20. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/predicted_tokens.csv +3 -0
  21. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/raw_mutation_results.pkl +3 -0
  22. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/conservation_heatmap.png +0 -0
  23. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/full_results_with_logits.csv +3 -0
  24. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/predicted_tokens.csv +3 -0
  25. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/raw_mutation_results.pkl +3 -0
  26. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/conservation_heatmap.png +0 -0
  27. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/full_results_with_logits.csv +3 -0
  28. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/predicted_tokens.csv +3 -0
  29. fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/raw_mutation_results.pkl +3 -0
  30. fuson_plm/benchmarking/mutation_prediction/discovery/viridis_color_bar.png +0 -0
  31. fuson_plm/benchmarking/mutation_prediction/recovery/abl_mutations.csv +3 -0
  32. fuson_plm/benchmarking/mutation_prediction/recovery/alk_mutations.csv +3 -0
  33. fuson_plm/benchmarking/mutation_prediction/recovery/color_recovered_mutations_public.ipynb +314 -0
  34. fuson_plm/benchmarking/mutation_prediction/recovery/config.py +3 -0
  35. fuson_plm/benchmarking/mutation_prediction/recovery/recover_public.py +330 -0
  36. fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv +3 -0
  37. fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv +3 -0
  38. fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - EML4 ALK Mutations.csv +3 -0
  39. fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - BCR ABL Mutations.csv +3 -0
  40. fuson_plm/benchmarking/puncta/README.md +1 -1
fuson_plm/benchmarking/caid/README.md CHANGED
@@ -158,8 +158,8 @@ Here we describe what each script does and which files each script creates.
158
  a. Drops rows where no structure was successfully downloaded
159
  b. Drops rows where the FO sequence from FusionPDB does not match the FO sequence from its own AlphaFold2 structure file
160
  c. ⭐️Saves **two final, cleaned databases**⭐️:
161
- a. ⭐️ **`FusionPDB_level2-3_cleaned_FusionGID_info.csv`**: includes ful IDs and structural information for the Hgene and Tgene of each FO. Columns="FusionGID","FusionGene","Hgene","Tgene","URL","HGID","TGID","HGUniProtAcc","TGUniProtAcc","HGUniProtAcc_Source","TGUniProtAcc_Source","HG_pLDDT","HG_AA_pLDDTs","HG_Seq","TG_pLDDT","TG_AA_pLDDTs","TG_Seq".
162
- b. ⭐️ **`FusionPDB_level2-3_cleaned_structure_info.csv`**: includes full structural information for each FO. Columns= "FusionGID","FusionGene","Fusion_Seq","Fusion_Length","Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","Level","Fusion_Structure_Link","Fusion_Structure_Type","Fusion_pLDDT","Fusion_AA_pLDDTs","Fusion_Seq_Source"
163
 
164
 
165
  ### Training
 
158
  a. Drops rows where no structure was successfully downloaded
159
  b. Drops rows where the FO sequence from FusionPDB does not match the FO sequence from its own AlphaFold2 structure file
160
  c. ⭐️Saves **two final, cleaned databases**⭐️:
161
+ a. ⭐️ **`FusionPDB_level2-3_cleaned_FusionGID_info.csv`**: includes full IDs and structural information for the Hgene and Tgene of each FO. Columns = "FusionGID", "FusionGene", "Hgene", "Tgene", "URL", "HGID", "TGID", "HGUniProtAcc", "TGUniProtAcc", "HGUniProtAcc_Source", "TGUniProtAcc_Source", "HG_pLDDT", "HG_AA_pLDDTs", "HG_Seq", "TG_pLDDT", "TG_AA_pLDDTs", "TG_Seq".
162
+ b. ⭐️ **`FusionPDB_level2-3_cleaned_structure_info.csv`**: includes full structural information for each FO. Columns = "FusionGID", "FusionGene", "Fusion_Seq", "Fusion_Length", "Hgene", "Hchr", "Hbp", "Hstrand", "Tgene", "Tchr", "Tbp", "Tstrand", "Level", "Fusion_Structure_Link", "Fusion_Structure_Type", "Fusion_pLDDT", "Fusion_AA_pLDDTs", "Fusion_Seq_Source"
163
 
164
 
165
  ### Training
fuson_plm/benchmarking/idr_prediction/README.md CHANGED
@@ -14,7 +14,7 @@ python plot.py # if you want to remake r2 plots
14
  ```
15
 
16
  ### Downloading raw IDR data
17
- IDR properties from [Lotthammer et al. 2024](https://doi.org/10.1038/s41592-023-02159-5) (ALBATROSS model) were used to train FusOn-pLM-Diso. Sequences were downloaded from [this link](https://github.com/holehouse-lab/supportingdata/blob/master/2023/ALBATROSS_2023/simulations/data/all_sequences.tgz) and deposited in `raw_data`. All files in `raw_data` are from this direct download.
18
 
19
  ```
20
  benchmarking/
@@ -144,7 +144,7 @@ The model is defined in `model.py` and `utils.py`. The `train.py` script trains
144
  - All **raw outputs from models** are stored in `idr_prediction/trained_models/<embedding_path>`, where `embedding_path` represents the embeddings used to build the disorder predictor.
145
  - All **embeddings** made for training will be stored in a new folder called `idr_prediction/embeddings/` with subfolders for each model. This allows you to use the same model multiple times without regenerating embeddings.
146
 
147
- Below is the FusOn-pLM-Diso raw outputs folder, `trained_models/fuson_plm/best/`, and the results from the paper, `results/final/`...
148
 
149
  The outputs are structured as follows:
150
 
 
14
  ```
15
 
16
  ### Downloading raw IDR data
17
+ IDR properties from [Lotthammer et al. 2024](https://doi.org/10.1038/s41592-023-02159-5) (ALBATROSS model) were used to train FusOn-pLM-IDR. Sequences were downloaded from [this link](https://github.com/holehouse-lab/supportingdata/blob/master/2023/ALBATROSS_2023/simulations/data/all_sequences.tgz) and deposited in `raw_data`. All files in `raw_data` are from this direct download.
18
 
19
  ```
20
  benchmarking/
 
144
  - All **raw outputs from models** are stored in `idr_prediction/trained_models/<embedding_path>`, where `embedding_path` represents the embeddings used to build the disorder predictor.
145
  - All **embeddings** made for training will be stored in a new folder called `idr_prediction/embeddings/` with subfolders for each model. This allows you to use the same model multiple times without regenerating embeddings.
146
 
147
+ Below is the FusOn-pLM-IDR raw outputs folder, `trained_models/fuson_plm/best/`, and the results from the paper, `results/final/`...
148
 
149
  The outputs are structured as follows:
150
 
fuson_plm/benchmarking/mutation_prediction/README.md ADDED
@@ -0,0 +1,114 @@
+ # Mutation Prediction benchmark
+
+ This folder contains all the data and code needed to perform the **mutation prediction benchmark** (Figure 5).
+
+ ## Recovery
+ In Fig. 5C, drug resistance mutations in BCR::ABL and EML4::ALK are recovered by FusOn-pLM. Since the full sequences are not publicly available, certain sections of the code and results files have been removed. All results for drug resistance mutation positions in each sequence are included in `mutation_prediction/recovery`.
+
+ ```
+ benchmarking/
+ └── mutation_prediction/
+     └── recovery/
+         ├── results/final_public/
+         │   ├── BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv
+         │   ├── Supplementary Tables - BCR ABL Mutations.csv
+         │   ├── EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv
+         │   └── Supplementary Tables - EML4 ALK Mutations.csv
+         ├── abl_mutations.csv
+         ├── alk_mutations.csv
+         ├── color_recovered_mutations_public.ipynb
+         ├── recover_public.py
+         └── config.py
+ ```
+ In the 📁 **`recovery`** directory:
+ - **`abl_mutations.csv`**: raw data from the literature with BCR::ABL mutations [(O'Hare et al. 2007)](https://doi.org/10.1182/blood-2007-03-066936)
+ - **`alk_mutations.csv`**: raw data from the literature with EML4::ALK mutations [(Elshatlawy et al. 2023)](https://doi.org/10.1002/1878-0261.13446)
+ - **`color_recovered_mutations_public.ipynb`**: notebook used to write PyMOL code for the visualizations in Fig. 5C
+ - **`recover_public.py`**: Python script that runs the analysis. Since the private sequences are removed, the script will not run as-is, but it documents all steps taken.
+ - **`config.py`**: small config file to specify which FusOn-pLM checkpoint is being used.
+
+ In the 📁 **`results/final_public`** directory:
+ - **`BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv`**: raw logits for every possible mutation at all positions in BCR::ABL with known drug resistance mutations
+ - **`Supplementary Tables - BCR ABL Mutations.csv`**: supplementary table showing the calculations going into the hit rate for BCR::ABL
+ - **`EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv`**: raw logits for every possible mutation at all positions in EML4::ALK with known drug resistance mutations
+ - **`Supplementary Tables - EML4 ALK Mutations.csv`**: supplementary table showing the calculations going into the hit rate for EML4::ALK
+
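The hit-rate calculation behind these supplementary tables can be sketched as follows (an illustrative helper, not the actual `recover_public.py` code; the positions and residues in the example are made up, not real BCR::ABL data):

```python
def mutation_hit_rate(ranked_predictions, known_mutations, top_n=3):
    """Fraction of known drug-resistance mutations whose substituted
    residue appears among the model's top-n predictions at that position.

    ranked_predictions: dict mapping 1-indexed position -> residues
        ordered from highest to lowest logit.
    known_mutations: iterable of (position, substituted_residue) pairs.
    """
    hits = sum(
        1 for pos, alt in known_mutations
        if alt in ranked_predictions.get(pos, [])[:top_n]
    )
    return hits / len(known_mutations)

# Toy example: two of the three known mutations fall in the top 3
ranked = {315: "IMT", 253: "HFY", 396: "PAG"}
known = [(315, "I"), (253, "Y"), (396, "C")]
print(mutation_hit_rate(ranked, known))  # 2/3
```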
+ ## Discovery
+ The `mutation_prediction/discovery/` directory contains all code and files needed to reproduce Fig. 5B and 5D, where mutations are predicted at each position in several fusion oncoproteins.
+
+ ### Data download
+ To help select TF- and kinase-containing fusions for investigation (Fig. 5B), Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8) was downloaded as a reference of transcription factors and kinases. This data is processed in `clean.py`.
+
+ ```
+ benchmarking/
+ └── mutation_prediction/
+     └── discovery/
+         ├── raw_data/
+         │   └── salokas_2020_tableS3.csv
+         └── processed_data/
+             ├── test_seqs_tftf_kk.csv
+             └── domain_conservation_fusions_inputfile.csv
+ ```
+
+ - **`raw_data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8)
+ - **`processed_data/test_seqs_tftf_kk.csv`**: fusion oncoproteins in the FusOn-pLM test set that have either a transcription factor (TF) or a kinase as the head or the tail.
+ - **`processed_data/domain_conservation_fusions_inputfile.csv`**: input file for discovery, containing the longest sequences of EWSR1::FLI1, PAX3::FOXO1, TRIM24::RET, and ETV6::NTRK3.
+
+ ### Mutation discovery
+
+ Run `discover.py` to perform discovery on the sequences specified in `config.py` (either one or many):
+
+ ```
+ # Model settings: where is the model you wish to use for mutation discovery?
+ FUSON_PLM_CKPT = "FusOn-pLM"
+
+ #### Fill in this section if you have one input
+ # Sequence settings: need the full sequence of the fusion oncoprotein and the bounds of the region of interest
+ FULL_FUSION_SEQUENCE = ""
+ FUSION_NAME = "fusion_example"
+ START_RESIDUE_INDEX = 1
+ END_RESIDUE_INDEX = 100
+ N = 3 # number of mutations to predict per amino acid
+
+ #### Fill in this section if you have multiple inputs
+ PATH_TO_INPUT_FILE = "processed_data/domain_conservation_fusions_inputfile.csv" # if you don't have an input file and want to do one sequence, set this variable to None
+
+ # GPU Settings: which GPUs should be available to run this discovery?
+ CUDA_VISIBLE_DEVICES = "0"
+ ```
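As a rough illustration, the one-sequence vs. input-file switch described above could be resolved as in this sketch (the `config` attribute names follow the block above; this is not the actual `discover.py` logic):

```python
def load_discovery_inputs(config):
    """Return a list of (name, sequence, start, end, n) discovery jobs.

    Reads the batch CSV when PATH_TO_INPUT_FILE is set; otherwise falls
    back to the single-sequence settings in the config module.
    """
    if config.PATH_TO_INPUT_FILE is not None:
        import pandas as pd  # only needed for the batch branch
        df = pd.read_csv(config.PATH_TO_INPUT_FILE)
        cols = ["fusion_name", "full_fusion_sequence",
                "start_residue_index", "end_residue_index", "n"]
        return list(df[cols].itertuples(index=False, name=None))
    return [(config.FUSION_NAME, config.FULL_FUSION_SEQUENCE,
             config.START_RESIDUE_INDEX, config.END_RESIDUE_INDEX, config.N)]
```

With `PATH_TO_INPUT_FILE = None`, this yields a single job built from the `FULL_FUSION_SEQUENCE` settings.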
+
+ To run, use:
+ ```
+ nohup python discover.py > discover.out 2> discover.err &
+ ```
+ - All **results** are stored in `discovery/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started the discovery run.
+
+ Below are the FusOn-pLM paper results in `results/final`:
+
+ ```
+ benchmarking/
+ └── mutation_prediction/
+     └── discovery/
+         └── results/final/
+             ├── EWSR1::FLI1/
+             │   ├── conservation_heatmap.png
+             │   ├── full_results_with_logits.csv
+             │   ├── predicted_tokens.csv
+             │   └── raw_mutation_results.pkl
+             ├── PAX3::FOXO1/   # same format as results/final/EWSR1::FLI1...
+             ├── TRIM24::RET/   # same format as results/final/EWSR1::FLI1...
+             └── ETV6::NTRK3/   # same format as results/final/EWSR1::FLI1...
+ ```
+ In each fusion oncoprotein folder are:
+ - **`conservation_heatmap.png`**: the heatmap from Fig. 5B
+ - **`full_results_with_logits.csv`**: predictions for every residue in the sequence. Columns = "Residue", "original_residue", "original_residue_logit", "all_logits", "top_3_mutations"
+ - **`predicted_tokens.csv`**: simplified format, top three tokens per residue and a 1/0 conserved/not conserved label. Columns = "Original Residue", "Predicted Residues", "Conserved", "Position"
+ - **`raw_mutation_results.pkl`**: raw logits in dictionary format.
+
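A `predicted_tokens.csv`-style summary can be derived from a per-position logit matrix roughly like this (a sketch assuming a 20-column amino-acid logit matrix in the order shown; not the exact pipeline code):

```python
import numpy as np

AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")  # assumed column order

def summarize_predictions(logit_matrix, original_seq, n=3):
    """Build predicted_tokens.csv-style rows: the top-n residues per
    position and a 1/0 flag for whether the original residue is among
    them (i.e. "conserved")."""
    rows = []
    for pos, (logits, orig) in enumerate(zip(logit_matrix, original_seq), start=1):
        order = np.argsort(logits)[::-1][:n]      # indices of top-n logits
        top = [AMINO_ACIDS[i] for i in order]
        rows.append({"Position": pos,
                     "Original Residue": orig,
                     "Predicted Residues": "".join(top),
                     "Conserved": int(orig in top)})
    return rows
```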
+ ### Plotting
+
+ Three scripts aid with plotting:
+
+ 1. `make_color_bar.py`: can be run to generate `viridis_color_bar.png`, a labeled, scaled-up version of the color bar used in the conservation heatmaps and to color an ETV6::NTRK3 structure (Fig. 5D)
+ 2. `plot.py`: includes code for heatmap generation. Can be run to regenerate heatmaps.
+ 3. `color_discovered_mutations.ipynb`: notebook used to generate a modified version of the ETV6::NTRK3 structure file, which stores the logits predicted by FusOn-pLM as the B-factor (`processed_data/521_logit_bfactor.cif`, displayed in Fig. 5D, right). It also includes PyMOL code for a head/tail visualization with recovered drug resistance mutations (displayed in Fig. 5D, left)
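For illustration, writing per-residue scores into the B-factor field is straightforward on PDB-format ATOM records (the notebook itself edits an mmCIF file; this simplified sketch uses the fixed-column PDB layout instead, and the residue-number keys are hypothetical):

```python
def set_bfactors(pdb_lines, b_by_resi):
    """Overwrite the B-factor field (columns 61-66) of each ATOM/HETATM
    record with a per-residue value keyed by the residue sequence number
    (columns 23-26). Records without a mapped value are left unchanged."""
    out = []
    for line in pdb_lines:
        if line.startswith(("ATOM", "HETATM")):
            resi = int(line[22:26])
            b = b_by_resi.get(resi)
            if b is not None:
                line = line[:60] + f"{b:6.2f}" + line[66:]
        out.append(line)
    return out
```

A structure rewritten this way can then be colored in PyMOL with `spectrum b`, as in the notebook.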
fuson_plm/benchmarking/mutation_prediction/discovery/clean.py ADDED
@@ -0,0 +1,71 @@
+ ### Clean the Salokas data, find TF and Kinase fusions in the test set
+ import pandas as pd
+ import os
+
+ def get_gene_type(gene, d):
+     # genes absent from the TF/kinase dictionary fall through to 'Other'
+     if gene in d:
+         if d[gene] == 'kinase':
+             return 'Kinase'
+         if d[gene] == 'tf':
+             return 'TF'
+     return 'Other'
+
+ # Load TF and Kinase Fusions
+ def main():
+     os.makedirs("processed_data", exist_ok=True)
+
+     tf_kinase_parts = pd.read_csv("raw_data/salokas_2020_tableS3.csv")
+     print(tf_kinase_parts)
+     ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'], tf_kinase_parts['Kinase or TF']))
+
+     ## Categorize everything in fuson_db
+     fuson_db = pd.read_csv("../../../data/fuson_db.csv")
+     print(fuson_db['benchmark'].value_counts())
+     print(fuson_db.loc[fuson_db['benchmark'].notna()])
+     fgenes = fuson_db.loc[fuson_db['benchmark'].notna()]['fusiongenes'].to_list()
+     print(fuson_db.columns)
+     print(fuson_db)
+
+     # This one has each row with one fusiongene name
+     fuson_ht_db = pd.read_csv("../../../data/blast/fuson_ht_db.csv")
+     print(fuson_ht_db.columns)
+     print(fuson_ht_db)
+     fuson_ht_db[['hg','tg']] = fuson_ht_db['fusiongenes'].str.split("::", expand=True)
+     print(fuson_ht_db.loc[fuson_ht_db['hg']=='PAX3'])
+     print(fuson_ht_db)
+
+     fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
+     fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
+     fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type']+'::'+fuson_ht_db['tg_type']
+     fuson_ht_db['type'] = ['fusion']*len(fuson_ht_db)
+
+     # Keep things in the test set
+     test_set = pd.read_csv("../../../data/splits/test_df.csv")
+     print(test_set.columns, len(test_set))
+     test_seqs = test_set['sequence'].tolist()
+     fuson_ht_db = fuson_ht_db.loc[
+         fuson_ht_db['aa_seq'].isin(test_seqs)
+     ].sort_values(by=['fusion_type']).reset_index(drop=True)
+     fuson_ht_db.to_csv("processed_data/test_seqs_tftf_kk.csv", index=False)
+
+     # isolate a few transcription factor fusions of interest and keep the longest sequence of each
+     fusion_genes_of_interest = [
+         "EWSR1::FLI1", "PAX3::FOXO1", "TRIM24::RET", "ETV6::NTRK3"
+     ]
+     df_of_interest = fuson_ht_db.loc[
+         fuson_ht_db['fusiongenes'].isin(fusion_genes_of_interest)
+     ].sort_values(by=['fusiongenes','length'], ascending=[True,False]).reset_index(drop=True).drop_duplicates(subset='fusiongenes').reset_index(drop=True)
+     #df_of_interest.to_csv("domain_conservation_fusions.csv", index=False)
+     # Make an input file for discover.py
+     discovery_input = df_of_interest[['fusiongenes','length','aa_seq']].copy()
+     discovery_input['start_residue_index'] = [1]*len(discovery_input)
+     discovery_input['n'] = [3]*len(discovery_input)
+     discovery_input = discovery_input.rename(columns={'length': 'end_residue_index',
+                                                       'aa_seq': 'full_fusion_sequence',
+                                                       'fusiongenes': 'fusion_name'})
+     discovery_input[['fusion_name','full_fusion_sequence','start_residue_index','end_residue_index','n']].to_csv("processed_data/domain_conservation_fusions_inputfile.csv", index=False)
+     print(discovery_input)
+
+ if __name__ == "__main__":
+     main()
fuson_plm/benchmarking/mutation_prediction/discovery/color_discovered_mutations.ipynb ADDED
@@ -0,0 +1,418 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "FJd6a9gdZNjG"
7
+ },
8
+ "source": [
9
+ "### Imports"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "!pip install torch pandas numpy py3Dmol scikit-learn"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 37,
24
+ "metadata": {
25
+ "id": "ZEWZVc9lUxjI"
26
+ },
27
+ "outputs": [],
28
+ "source": [
29
+ "import torch\n",
30
+ "import torch.nn as nn\n",
31
+ "\n",
32
+ "import pickle\n",
33
+ "import pandas as pd\n",
34
+ "import numpy as np\n",
35
+ "\n",
36
+ "import py3Dmol\n",
37
+ "\n",
38
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import os\n",
48
+ "os.getcwd()"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Look at all results for ETV6::NTRK3 discovery"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "import pickle\n",
65
+ "path_to_pkl = \"../discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl\"\n",
66
+ "with open(path_to_pkl, \"rb\") as f:\n",
67
+ " etv6_ntrk3_logits = pickle.load(f)\n",
68
+ "\n",
69
+ "print(etv6_ntrk3_logits)"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 40,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# Define paths and dataframes that we will need \n",
79
+ "fusion_benchmark_set = pd.read_csv('../../caid/splits/fusion_bench_df.csv')\n",
80
+ "fusion_structure_data = pd.read_csv('../../caid/processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')\n",
81
+ "fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])\n",
82
+ "fusion_structure_folder = \"raw_data/fusionpdb/structures\""
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "# merge fusion data with seq ids \n",
92
+ "fuson_db = pd.read_csv('../../../data/fuson_db.csv')\n",
93
+ "print(fuson_db.columns)\n",
94
+ "fuson_db = fuson_db[['aa_seq','seq_id']].rename(columns={'aa_seq':'Fusion_Seq'})\n",
95
+ "print(f\"Length of fusion structure data before merge on seqid: {len(fusion_structure_data)}\")\n",
96
+ "fusion_structure_data = pd.merge(\n",
97
+ " fusion_structure_data,\n",
98
+ " fuson_db,\n",
99
+ " on='Fusion_Seq',\n",
100
+ " how='inner'\n",
101
+ ")\n",
102
+ "print(f\"Length of fusion structure data after merge on seqid: {len(fusion_structure_data)}\")\n",
103
+ "fusion_structure_data"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "# merge fusion structure data with top swissprot alignments\n",
113
+ "swissprot_top_alignments = pd.read_csv(\"../../../data/blast/blast_outputs/swissprot_top_alignments.csv\")\n",
114
+ "fusion_structure_data = pd.merge(\n",
115
+ " fusion_structure_data,\n",
116
+ " swissprot_top_alignments,\n",
117
+ " on=\"seq_id\",\n",
118
+ " how=\"left\"\n",
119
+ ")\n",
120
+ "fusion_structure_data.head()"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "selected_row = fusion_structure_data.loc[\n",
130
+ " (fusion_structure_data['FusionGene'].str.contains('ETV6::NTRK3')) &\n",
131
+ " (fusion_structure_data['aa_seq_len']==528)\n",
132
+ "].reset_index(drop=True).iloc[0,:]\n",
133
+ "selected_row"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "seq = selected_row[\"Fusion_Seq\"]\n",
143
+ "kinase_seq = \"IVLKRELGEGAFGKVFLAECYNLSPTKDKMLVAVKALKDPTLAARKDFQREAELLTNLQHEHIVKFYGVCGDGDPLIMVFEYMKHGDLNKFLRAHGPDAMILVDGQPRQAKGELGLSQMLHIASQIASGMVYLASQHFVHRDLATRNCLVGANLLVKIGDFGMSRDVYSTDYYRLFNPSGNDFCIWCEVGGHTMLPIRWMPPESIMYRKFTTESDVWSFGVILWEIFTYGKQPWFQLSNTEVIECITQGRVLERPRVCPKEVYDVMLGCWQREPQQRLNIKEIYKILHALGKATPIYLDILG\"\n",
144
+ "print(\"Length of kinase domain: \", len(kinase_seq))\n",
145
+ "print(\"1-indexed start position of kinase domain\", seq.index(kinase_seq)-1)\n",
146
+ "print(\"1-indexed end position of kinase domain (inclusive)\", seq.index(kinase_seq)-1+len(kinase_seq)-1)\n"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 45,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "kinase_seq = \"IVLKRELGEGAFGKVFLAECYNLSPTKDKMLVAVKALKDPTLAARKDFQREAELLTNLQHEHIVKFYGVCGDGDPLIMVFEYMKHGDLNKFLRAHGPDAMILVDGQPRQAKGELGLSQMLHIASQIASGMVYLASQHFVHRDLATRNCLVGANLLVKIGDFGMSRDVYSTDYYRLFNPSGNDFCIWCEVGGHTMLPIRWMPPESIMYRKFTTESDVWSFGVILWEIFTYGKQPWFQLSNTEVIECITQGRVLERPRVCPKEVYDVMLGCWQREPQQRLNIKEIYKILHALGKATPIYLDILG\""
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "!pip install transformers"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 47,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "from transformers import AutoTokenizer\n",
174
+ "import torch.nn.functional as F\n",
175
+ "fuson_ckpt_path = \"../../../..\"\n",
176
+ "fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "print(len(etv6_ntrk3_logits['originals_logits']))\n",
186
+ "print(len(etv6_ntrk3_logits['conservation_likelihoods']))\n",
187
+ "print(len(etv6_ntrk3_logits['logits_for_each_AA']))\n",
188
+ "\n",
189
+ "start = etv6_ntrk3_logits['start']\n",
190
+ "end = etv6_ntrk3_logits['end']\n",
191
+ "originals_logits = etv6_ntrk3_logits['originals_logits']\n",
192
+ "conservation_likelihoods = etv6_ntrk3_logits['conservation_likelihoods']\n",
193
+ "logits = etv6_ntrk3_logits['logits']\n",
194
+ "logits_for_each_AA = etv6_ntrk3_logits['logits_for_each_AA']\n",
195
+ "filtered_indices = etv6_ntrk3_logits['filtered_indices']\n",
196
+ "top_n_mutations = etv6_ntrk3_logits['top_n_mutations']\n",
197
+ "\n",
198
+ "token_indices = torch.arange(logits.size(-1))\n",
199
+ "tokens = [fuson_tokenizer.decode([idx]) for idx in token_indices]\n",
200
+ "filtered_tokens = [tokens[i] for i in filtered_indices]\n",
201
+ "all_logits_array = np.vstack(logits_for_each_AA)\n",
202
+ "normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()\n",
203
+ "transposed_logits_array = normalized_logits_array.T"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "# add these logits as a b factor to the pdb file \n",
213
+ "import re\n",
214
+ "\n",
215
+ "path_to_cif = f\"../../caid/raw_data/fusionpdb/structures/{selected_row['Fusion_Structure_Link']}\"\n",
216
+ "path_to_modified_cif = f\"{selected_row['Fusion_Structure_Link'].split('.')[0]}_logit_bfactor.cif\"\n",
217
+ "\n",
218
+ "def modify(path_to_cif, path_to_modified_cif, new_b_values):\n",
219
+ " with open(path_to_cif, 'r') as f:\n",
220
+ " lines = f.readlines()\n",
221
+ "\n",
222
+ " exline = ''\n",
223
+ " with open(path_to_modified_cif, 'w') as f:\n",
224
+ " in_aa_loop=False\n",
225
+ " in_atom_loop=False\n",
226
+ " done_aa=False\n",
227
+ " done_atom=False\n",
228
+ " counter_aa=0\n",
229
+ " counter_atom=0\n",
230
+ " seqlen=len(selected_row['Fusion_Seq'])\n",
231
+ " for line in lines:\n",
232
+ " newline=line\n",
233
+ " if line.startswith(\"A\\t\") and not(done_aa):\n",
234
+ " in_aa_loop=True\n",
235
+ " else:\n",
236
+ " in_aa_loop=False\n",
237
+ " if line.startswith(\"ATOM \") and not(done_atom):\n",
238
+ " in_atom_loop=True\n",
239
+ " else:\n",
240
+ " in_atom_loop=False\n",
241
+ " \n",
242
+ " if in_aa_loop:\n",
243
+ " counter_aa+=1\n",
244
+ " split_line = line.split()\n",
245
+ " one_indexed_pos = int(split_line[2])\n",
246
+ " new_value = new_b_values[one_indexed_pos]\n",
247
+ " new_value = round(new_value,4)\n",
248
+ " split_line[4] = str(new_value)\n",
249
+ " newline = '\\t'.join(split_line)+'\\n'\n",
250
+ " \n",
251
+ " if in_atom_loop:\n",
252
+ " counter_atom+=1\n",
253
+ " split_line = re.split(r'(\\s+)', line)\n",
254
+ " one_indexed_pos = int(split_line[16])\n",
255
+ " new_value = new_b_values[one_indexed_pos]\n",
256
+ " new_value = round(new_value,4)\n",
257
+ " split_line[28] = str(new_value)\n",
258
+ " newline = ''.join(split_line)\n",
259
+ " #print(split_line)\n",
260
+ " \n",
261
+ " if counter_aa==seqlen:\n",
262
+ " done_aa=True\n",
263
+ " in_aa_loop=False\n",
264
+ " \n",
265
+ " f.write(newline)\n"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "# read ETV6::NTRK3 predictions\n",
275
+ "path_to_preds = \"../discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv\"\n",
276
+ "preds = pd.read_csv(path_to_preds)\n",
277
+ "original_seq = ''.join(preds['original_residue'].tolist())\n",
278
+ "print(original_seq)\n",
279
+ "\n",
280
+ "# what is the structural data for this\n",
281
+ "structure_data = fusion_structure_data.loc[\n",
282
+ " fusion_structure_data['Fusion_Seq']==original_seq\n",
283
+ "].reset_index(drop=True).loc[0]\n",
284
+ "print(structure_data.to_string())\n",
285
+ "print(structure_data['top_hg_UniProt_fus_indices'], structure_data['top_tg_UniProt_fus_indices'])\n",
286
+ "print(structure_data['Fusion_Structure_Link'])\n"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "metadata": {},
293
+ "outputs": [],
294
+ "source": [
295
+ "path_to_cif = f\"../../caid/raw_data/fusionpdb/structures/{structure_data['Fusion_Structure_Link']}\"\n",
296
+ "path_to_modified_cif = f\"processed_data/{structure_data['Fusion_Structure_Link'].split('.')[0]}_logit_bfactor.cif\"\n",
297
+ "\n",
298
+ "new_b_values = dict(zip(preds[\"Residue\"],preds[\"original_residue_logit\"]))\n",
299
+ "\n",
300
+ "modify(path_to_cif, path_to_modified_cif, new_b_values)"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "# spectrum b to color by logits\n",
310
+ "spectrum b\n",
311
+ "\n",
312
+ "sele head, resi 1-338\n",
313
+ "sele tail, resi 339-647\n",
314
+ "sele kinase, resi 346-647\n",
315
+ "set cartoon_transparency, 0.8, tail\n",
316
+ "set cartoon_transparency, 0.8, kinase\n",
317
+ "sele G623R, resi 431\n",
318
+ "sele G696A, resi 504\n",
319
+ "color orange, G623R\n",
320
+ "set cartoon_transparency, 0, G623R\n",
321
+ "show licorice, G623R\n",
322
+ "color orange, G696A\n",
323
+ "set cartoon_transparency, 0, G696A\n",
324
+ "show licorice, G696A"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "metadata": {},
331
+ "outputs": [],
332
+ "source": [
333
+ "# PYMOL code for ETV6::NTRK3\n",
334
+ "\n",
335
+ "\n",
336
+ "# Define a custom color for the kinase domain\n",
337
+ "set_color custom_red, [0xdf/255, 0x83/255, 0x85/255]\n",
338
+ "set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
339
+ "\n",
340
+ "# Select and color residues 1085-1336\n",
341
+ "sele head, resi 1-338\n",
342
+ "sele tail, resi 339-647\n",
343
+ "sele kinase, resi 346-647\n",
344
+ "color custom_red, head\n",
345
+ "color custom_blue, tail\n",
346
+ "set cartoon_transparency, 0, tail\n",
347
+ "set cartoon_transparency, 0.8, kinase\n",
348
+ "sele G623R, resi 431\n",
349
+ "sele G696A, resi 504\n",
350
+ "color magenta, G623R\n",
351
+ "set cartoon_transparency, 0, G623R\n",
352
+ "show licorice, G623R\n",
353
+ "color magenta, G696A\n",
354
+ "set cartoon_transparency, 0, G696A\n",
355
+ "show licorice, G696A\n",
356
+ "\n",
357
+ "# Color the known mutations that impact drug efficacy\n",
358
+ "\n",
359
+ "# Select missed mutation residues, color them viridis, and make them fully opaque\n",
360
+ "# do the viridis outside the "
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "# PYMOL code for ETV6::NTRK3\n",
370
+ "# Set global cartoon transparency\n",
371
+ "# 0 for zoomed out, 0.5 for close up\n",
372
+ "color gray60\n",
373
+ "set cartoon_transparency, 0\n",
374
+ "\n",
375
+ "# Define a custom color for the kinase domain\n",
376
+ "set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
377
+ "\n",
378
+ "# Select and color residues 1085-1336\n",
379
+ "sele ntrk3_kinase, resi 225-526\n",
380
+ "color custom_blue, ntrk3_kinase\n",
381
+ "\n",
382
+ "# Select missed mutation residues, color them orange, and make them fully opaque\n",
383
+ "spectrum b, blue_white_red, minimum=0.8, maximum=1"
384
+ ]
385
+ }
386
+ ],
387
+ "metadata": {
388
+ "colab": {
389
+ "collapsed_sections": [
390
+ "FJd6a9gdZNjG",
391
+ "zORkLJztZWp9",
392
+ "w25hagtZaV65",
393
+ "IbyqxlvAFUAK",
394
+ "0n5PSprbhLk7"
395
+ ],
396
+ "machine_shape": "hm",
397
+ "provenance": []
398
+ },
399
+ "kernelspec": {
400
+ "display_name": "Python 3",
401
+ "name": "python3"
402
+ },
403
+ "language_info": {
404
+ "codemirror_mode": {
405
+ "name": "ipython",
406
+ "version": 3
407
+ },
408
+ "file_extension": ".py",
409
+ "mimetype": "text/x-python",
410
+ "name": "python",
411
+ "nbconvert_exporter": "python",
412
+ "pygments_lexer": "ipython3",
413
+ "version": "3.10.12"
414
+ }
415
+ },
416
+ "nbformat": 4,
417
+ "nbformat_minor": 0
418
+ }
fuson_plm/benchmarking/mutation_prediction/discovery/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model settings: where is the model you wish to use for mutation discovery?
2
+ FUSON_PLM_CKPT = "FusOn-pLM"
3
+
4
+ #### Fill in this section if you have one input
5
+ # Sequence settings: need full sequence of fusion oncoprotein, and the bounds of region of interest
6
+ FULL_FUSION_SEQUENCE = ""
7
+ FUSION_NAME = "fusion_example"
8
+ START_RESIDUE_INDEX = 1
9
+ END_RESIDUE_INDEX = 100
10
+ N = 3 # number of mutations to predict per amino acid
11
+
12
+ #### Fill in this section if you have multiple inputs
13
+ PATH_TO_INPUT_FILE = "processed_data/domain_conservation_fusions_inputfile.csv" # if you don't have an input file and want to do one sequence, set this variable to None
14
+
15
+ # GPU Settings: which GPUs should be available to run this discovery?
16
+ CUDA_VISIBLE_DEVICES = "0"
fuson_plm/benchmarking/mutation_prediction/discovery/discover.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##### Discover mutations in new sequences: mask each position and rank the model's predicted residues
2
+ import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
3
+ import os
4
+ import pickle
5
+ os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
6
+
7
+ import pandas as pd
8
+ import numpy as np
9
+ import transformers
10
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
11
+ import logging
12
+ import torch
13
+ import matplotlib.pyplot as plt
14
+ import seaborn as sns
15
+ import torch.nn.functional as F
16
+
17
+ from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
18
+ from fuson_plm.utils.embedding import load_esm2_type
19
+ from fuson_plm.utils.visualizing import set_font
20
+ from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables
21
+ from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap
22
+
23
+ def check_seq_inputs(sequence, AAs_tokens):
24
+ # check the sequence input for validity; raise on the first problem found
25
+ if not sequence.strip():
26
+ raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
27
+ if any(char not in AAs_tokens for char in sequence):
28
+ raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
31
+
32
+ def check_domain_bounds(domain_bounds, sequence=None):
33
+ # validate the bounds before returning them; every raise must come before the return
34
+ try:
35
+ start = int(domain_bounds['start'])
36
+ end = int(domain_bounds['end'])
37
+ except ValueError:
38
+ raise Exception("Error: Start and end indices must be integers.")
39
+ if start >= end:
40
+ raise Exception("Start index must be smaller than end index.")
41
+ if start <= 0 or end <= 0:
42
+ raise Exception("Indexing starts at 1; domain bounds must be positive integers.")
43
+ if sequence is not None and (start > len(sequence) or end > len(sequence)):
44
+ raise Exception("Domain bounds exceed sequence length.")
45
+ return start, end
52
+
53
+ def check_n_input(n):
54
+ # N is the number of mutations predicted per position; it must be at least 1
55
+ if n < 1:
56
+ raise Exception("Choose N >= 1")
57
+
58
+ def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device):
59
+ # Define amino acids and their token indices
60
+ AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
61
+ AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
62
+ 'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}
63
+
64
+ # checking all inputs for validity
65
+ log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
66
+ check_seq_inputs(sequence, AAs_tokens)
67
+ start, end = check_domain_bounds(domain_bounds)
68
+ check_n_input(n)
69
+
70
+ # define start_index as start - 1 (because residues are 1-indexed, while Python is 0-indexed). end is same
71
+ start_index = start - 1
72
+ end_index = end
73
+
74
+ # place to store top n mutations and all logits
75
+ top_n_mutations = {}
76
+ top_n_advantage_mutations = {}
77
+ top_n_disadvantage_mutations = {}
78
+ logits_for_each_AA = []
79
+ llrs_for_each_AA = []
80
+
81
+ # storage for the conservation heatmap
82
+ originals_logits = []
83
+ conservation_likelihoods = {}
84
+
85
+ log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
86
+ for i in range(len(sequence)):
87
+ # only iterate through the residues inside the domain
88
+ if start_index <= i <= (end_index - 1):
89
+ # isolate original residue and its index
90
+ original_residue = sequence[i]
91
+ original_residue_index = AAs_tokens_indices[original_residue]
92
+ masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
93
+
94
+ # prepare log
95
+ masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
96
+ log_starti = i-min(5, i)
97
+ log_endi = i+5
98
+ log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")
99
+ # prepare inputs
100
+ inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
101
+ inputs = {k: v.to(device) for k, v in inputs.items()}
102
+ # forward pass
103
+ with torch.no_grad():
104
+ logits = model(**inputs).logits
105
+
106
+ # Find masked positions and extract their logits
107
+ mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
108
+ mask_token_logits = logits[0, mask_token_index, :] # shape: [1, vocab_size] == [1, 33]. logits for each vocab word at this position in the sequence
109
+
110
+ # Collect logits for the full heatmap
111
+ logits_array = mask_token_logits.cpu().numpy() # shape: [1, 33]
112
+ # filter out non-amino acid tokens
113
+ filtered_indices = list(range(4, 23 + 1)) # filtered indices are indices of amino acids
114
+ filtered_logits = logits_array[:, filtered_indices] # shape: [1, 20] only the 20 amino acids
115
+ logits_for_each_AA.append(filtered_logits) # get logits for each amino acid
116
+
117
+ # Collect LLRs for the LLR heatmap
118
+ log_probabilities = F.log_softmax(mask_token_logits.cpu(), dim=-1).squeeze(0) # take log softmax of the [33] dimension; mask_token_logits is already a tensor
119
+ log_prob_og = log_probabilities[original_residue_index] # get the log probability of the TRUE residue underneath the mask
120
+ llrs = log_probabilities - log_prob_og # calculate the LLR for every token in one vectorized subtraction
121
+ #print(original_residue_index, llrs)
122
+ filtered_llrs = llrs[filtered_indices].numpy() # filter so it's [20], just the amino acids; only save this
123
+ filtered_llrs_array = np.array([filtered_llrs])
124
+ llrs_for_each_AA.append(filtered_llrs_array)
125
+
126
+ ######### Top tokens
127
+ # Get top tokens based on LOGITS
128
+ all_tokens_logits = mask_token_logits.squeeze(0) # shape: [vocab_size] == [33]
129
+ top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True) # sort the logits
130
+ mutation = []
131
+ # make sure we don't include non-AA tokens
132
+ for token_index in top_tokens_indices:
133
+ decoded_token = tokenizer.decode([token_index.item()])
134
+ # decoded all tokens, pick the top n amino acid ones
135
+ if decoded_token in AAs_tokens:
136
+ mutation.append(decoded_token)
137
+ if len(mutation) == n:
138
+ break
139
+ top_n_mutations[(sequence[i], i)] = mutation
140
+ log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")
141
+
142
+ # Get top tokens based on LLR
143
+ top_advantage_tokens_indices = torch.argsort(llrs, dim=0, descending=True) # sort the LLRs
144
+ advantage_mutation = []
145
+ # make sure we don't include non-AA tokens
146
+ for token_index in top_advantage_tokens_indices:
147
+ decoded_token = tokenizer.decode([token_index.item()])
148
+ # decoded all tokens, pick the top n amino acid ones
149
+ if decoded_token in AAs_tokens:
150
+ advantage_mutation.append(decoded_token)
151
+ if len(advantage_mutation) == n:
152
+ break
153
+ top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
154
+ log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")
155
+
156
+ # Get bottom tokens based on LLR (most disadvantageous mutations)
157
+ top_disadvantage_tokens_indices = torch.argsort(llrs, dim=0, descending=False) # sort the LLRs
158
+ disadvantage_mutation = []
159
+ # make sure we don't include non-AA tokens
160
+ for token_index in top_disadvantage_tokens_indices:
161
+ decoded_token = tokenizer.decode([token_index.item()])
162
+ # decoded all tokens, pick the top n amino acid ones
163
+ if decoded_token in AAs_tokens:
164
+ disadvantage_mutation.append(decoded_token)
165
+ if len(disadvantage_mutation) == n:
166
+ break
167
+ top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
168
+ log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")
169
+
170
+ # fill in the logits and conservation likelihoods for the second array
171
+ normalized_mask_token_logits = F.softmax(mask_token_logits.cpu(), dim=-1).numpy()
172
+ normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
173
+ originals_logit = normalized_mask_token_logits[original_residue_index]
174
+ originals_logits.append(originals_logit)
175
+
176
+ # a position is conserved if the probability of the original amino acid is > 0.7, i.e., the probability of any mutation is < 0.3
177
+ if originals_logit > 0.7:
178
+ conservation_likelihoods[(original_residue, i)] = 1
179
+ log_update("\t\tConserved position")
180
+ else:
181
+ conservation_likelihoods[(original_residue, i)] = 0
182
+ log_update("\t\tNot conserved position")
183
+
184
+ # return a dictionary of all the things we need for the next part
185
+ return {
186
+ 'start': start,
187
+ 'end': end,
188
+ 'originals_logits': originals_logits,
189
+ 'conservation_likelihoods': conservation_likelihoods,
190
+ 'logits': logits,
191
+ 'filtered_indices': filtered_indices,
192
+ 'top_n_mutations': top_n_mutations,
193
+ 'top_n_advantage_mutations': top_n_advantage_mutations,
194
+ 'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
195
+ 'logits_for_each_AA': logits_for_each_AA,
196
+ 'llrs_for_each_AA': llrs_for_each_AA
197
+ }
198
+
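The conservation call above reduces to a simple threshold on the softmax probability of the original residue under the mask. A minimal standard-library sketch with toy logits (no model involved; the helper names are illustrative):

```python
import math

def softmax(logits):
    # numerically stable softmax over a list of raw logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def is_conserved(prob_original, threshold=0.7):
    # mirrors the rule above: 1 if the original residue's probability
    # exceeds 0.7 (mutation probability < 0.3), else 0
    return 1 if prob_original > threshold else 0

# toy logits for a 4-token vocabulary; token 0 is the original residue
probs = softmax([3.0, 0.5, 0.2, 0.1])
flag = is_conserved(probs[0])
```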
199
+ def find_top_3(d):
200
+ temp = pd.DataFrame.from_dict(d, orient='index').reset_index()
201
+ temp = temp.sort_values(by=0,ascending=False).reset_index(drop=True)
202
+ temp = temp.iloc[0:3,:]
203
+ return_d = dict(zip(temp['index'],temp[0]))
204
+ return return_d
205
+
206
+ def make_full_results_df(mutation_results, tokenizer, original_sequence):
207
+ # Unpack mutation results
208
+ logits = mutation_results['logits']
209
+ logits_for_each_AA = mutation_results['logits_for_each_AA']
210
+ filtered_indices = mutation_results['filtered_indices']
211
+ llrs_for_each_AA = mutation_results['llrs_for_each_AA']
212
+
213
+ token_indices = torch.arange(logits.size(-1))
214
+ tokens = [tokenizer.decode([idx]) for idx in token_indices]
215
+ filtered_tokens = [tokens[i] for i in filtered_indices]
216
+ all_logits_array = np.vstack(logits_for_each_AA)
217
+ normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
218
+
219
+ df = pd.DataFrame(normalized_logits_array) # rows are positions, columns are the 20 amino acids
221
+ df.columns = filtered_tokens
222
+ df.index = list(range(1, len(df)+1))
223
+ df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
224
+ df['top_3_mutations'] = df['all_logits'].apply(lambda x: find_top_3(x))
225
+ df['original_residue'] = list(original_sequence)
226
+ df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']],axis=1)
227
+ df = df[['original_residue','original_residue_logit','all_logits','top_3_mutations']]
228
+ df = df.reset_index().rename(columns={'index':'Residue'})
229
+ return df
230
+
231
+ def make_small_results_df(mutation_results):
232
+ conservation_likelihoods = mutation_results['conservation_likelihoods']
233
+ top_n_mutations = mutation_results['top_n_mutations']
234
+
235
+ # store the predicted mutations in a dataframe
236
+ original_residues = []
237
+ mutations = []
238
+ positions = []
239
+ conserved = []
240
+
241
+ for key, value in top_n_mutations.items():
242
+ original_residue, position = key
243
+ original_residues.append(original_residue)
244
+ mutations.append(','.join(value))
245
+ positions.append(position + 1)
246
+
247
+ for i, (key, value) in enumerate(conservation_likelihoods.items()):
248
+ original_residue, position = key
249
+ if original_residues[i]==original_residue: # it should, otherwise something is wrong
250
+ conserved.append(value)
251
+
252
+ df = pd.DataFrame({
253
+ 'Original Residue': original_residues,
254
+ 'Predicted Residues': mutations,
255
+ 'Conserved': conserved,
256
+ 'Position': positions
257
+ })
258
+ return df
259
+
260
+
261
+ def main():
262
+ # Make results directory
263
+ os.makedirs('results',exist_ok=True)
264
+ output_dir = f'results/{get_local_time()}'
265
+ os.makedirs(output_dir,exist_ok=True)
266
+
267
+ # Predict mutations, writing results to a log inside of the output directory
268
+ with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
269
+ print_configpy(config)
270
+ # Make sure environment variables are set correctly
271
+ check_env_variables()
272
+
273
+ # Get device
274
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
275
+ log_update(f"Using device: {device}")
276
+
277
+ # Load fuson as AutoModelForMaskedLM
278
+ fuson_ckpt_path = config.FUSON_PLM_CKPT
279
+ if fuson_ckpt_path=="FusOn-pLM":
280
+ fuson_ckpt_path="../../../.."
281
+ model_name = "fuson_plm"
282
+ model_epoch = "best"
283
+ model_str = f"fuson_plm/best"
284
+ else:
285
+ model_name = list(fuson_ckpt_path.keys())[0]
286
+ epoch = list(fuson_ckpt_path.values())[0]
287
+ fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
288
+ model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
289
+ model_epoch = model_epoch.split('checkpoint_')[-1]
290
+ model_str = f"{model_name}/{model_epoch}"
291
+
292
+ log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
293
+ fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
294
+ fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
295
+ fuson_model.to(device)
296
+ fuson_model.eval()
297
+
298
+ if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
299
+ input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
300
+ else:
301
+ input_file = pd.DataFrame(
302
+ data={
303
+ 'fusion_name': [config.FUSION_NAME],
304
+ 'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
305
+ 'start_residue_index': [config.START_RESIDUE_INDEX],
306
+ 'end_residue_index': [config.END_RESIDUE_INDEX],
307
+ 'n': [config.N]
308
+ }
309
+ )
310
+
311
+ log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
312
+ log_update("\t" + "\n\t".join(input_file['fusion_name']))
313
+ # Loop through each input and make a subfolder with its data
314
+ for i in range(len(input_file)):
315
+ row = input_file.loc[i,:]
316
+ fusion_name = row['fusion_name']
317
+ full_fusion_sequence = row['full_fusion_sequence']
318
+ start_residue_index = row['start_residue_index']
319
+ end_residue_index = row['end_residue_index']
320
+ n = row['n']
321
+ sub_output_dir = f"{output_dir}/{fusion_name}"
322
+ os.makedirs(sub_output_dir,exist_ok=True)
323
+
324
+ # Predict positionwise mutations, plot the results
325
+ domain_bounds = {'start': start_residue_index, 'end': end_residue_index}
326
+ mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, n,
327
+ fuson_model, fuson_tokenizer, device)
328
+
329
+ # Save mutation results
330
+ with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
331
+ pickle.dump(mutation_results, f)
332
+
333
+ # Plot the heatmaps
334
+ plot_full_heatmap(mutation_results, fuson_tokenizer,
335
+ fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
336
+ plot_conservation_heatmap(mutation_results,
337
+ fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")
338
+
339
+ # Make results dataframe
340
+ small_mutation_results_df = make_small_results_df(mutation_results)
341
+ small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv",index=False)
342
+ full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
343
+ full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv",index=False)
344
+
345
+ if __name__ == "__main__":
346
+ main()
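The per-position LLR computed in `predict_positionwise_mutations` is the log-softmax of the masked-position logits minus the log-probability of the original residue. A minimal dependency-free sketch with toy logits (the function names here are illustrative):

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a list of raw logits
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def llrs_from_logits(logits, original_index):
    # LLR[j] = log P(token j) - log P(original residue under the mask)
    logp = log_softmax(logits)
    ref = logp[original_index]
    return [lp - ref for lp in logp]

# toy logits over a 4-token vocabulary; the original residue is token 0
llrs = llrs_from_logits([2.0, 1.0, 0.5, -1.0], 0)
```

Because the normalizing constant cancels, each LLR is just a logit difference; positive LLRs mark substitutions the model favors over the original residue and negative LLRs mark disfavored ones, matching how `top_n_advantage_mutations` and `top_n_disadvantage_mutations` are ranked.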
fuson_plm/benchmarking/mutation_prediction/discovery/make_color_bar.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+ def plot_color_bar():
5
+ """
6
+ Create a Viridis color bar ranging from 0 to 1.
7
+ """
8
+ # Create a gradient from 0 to 1
9
+ gradient = np.linspace(0, 1, 256).reshape(1, -1)
10
+
11
+ # Plot the gradient as a color bar
12
+ fig, ax = plt.subplots(figsize=(12, 3))
13
+ ax.imshow(gradient, aspect="auto", cmap="viridis")
14
+ ax.set_xticks([0, 255])
15
+ ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
16
+ ax.set_yticks([])
17
+ ax.set_title("Original Residue Logits", fontsize=40)
18
+
19
+ # Save the figure before displaying it (plt.show() can clear the figure in some backends)
20
+ plt.tight_layout()
21
+ plt.savefig("viridis_color_bar.png", dpi=300)
22
+ plt.show()
23
+
24
+ # Call the function to create and display the color bar
25
+ plot_color_bar()
fuson_plm/benchmarking/mutation_prediction/discovery/plot.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import os
7
+ import pandas as pd
8
+ import pickle
9
+ from transformers import AutoTokenizer
10
+ from fuson_plm.utils.visualizing import set_font
11
+ import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
12
+
13
+ def get_x_tick_labels(start, end):
14
+ # Define start and end index which we actually use to index the sequence
15
+ start_index = start - 1
16
+ end_index = end
17
+
18
+ # Define domain length
19
+ domain_len = end - start
20
+ if 500 > domain_len > 100:
21
+ step_size = 50
22
+ elif 500 <= domain_len:
23
+ step_size = 100
24
+ elif domain_len < 10:
25
+ step_size = 1
26
+ else:
27
+ step_size = 10
28
+
29
+ # Define x tick positions based on step size
30
+ x_tick_positions = np.arange(start_index, end_index, step_size)
31
+ x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
32
+
33
+ return x_tick_positions, x_tick_labels
34
+
35
+
36
+ def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
37
+ start = mutation_results['start']
38
+ end = mutation_results['end']
39
+ originals_logits = mutation_results['originals_logits']
40
+ conservation_likelihoods = mutation_results['conservation_likelihoods']
41
+ logits = mutation_results['logits']
42
+ logits_for_each_AA = mutation_results['logits_for_each_AA']
43
+ filtered_indices = mutation_results['filtered_indices']
44
+ top_n_mutations = mutation_results['top_n_mutations']
45
+
46
+ # Get start index and end index
47
+ start_index = start - 1
48
+ end_index = end
49
+
50
+ # Make conservation likelihoods array for plotting
51
+ all_logits_array = np.vstack(originals_logits)
52
+ transposed_logits_array = all_logits_array.T
53
+ conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
54
+ # combine to make a 2D heatmap
55
+ combined_array = np.vstack((transposed_logits_array, conservation_likelihoods_array))
56
+
57
+ # Get ticks
58
+ x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)
59
+
60
+ # Plot!
61
+ set_font()
62
+ # Adjust the figure size: constant height (e.g., 3) and width proportional to sequence length
63
+ sequence_length = end_index - start_index
64
+ fig = plt.figure(figsize=(min(15, sequence_length / 10), 3)) # Adjust width dynamically, keep height constant
65
+
66
+ #plt.rcParams.update({'font.size': 16.5}) # make font size bigger
67
+ ax = sns.heatmap(
68
+ combined_array,
69
+ cmap='viridis',
70
+ xticklabels=x_tick_labels,
71
+ yticklabels=['Original Logits', 'Conserved'],
72
+ cbar=True,
73
+ cbar_kws={'aspect': 2,
74
+ 'pad': 0.02,
75
+ 'shrink': 1.0, # Adjust the overall size of the color bar
76
+ }
77
+ )
78
+ # Access the color bar
79
+ cbar = ax.collections[0].colorbar
80
+
81
+ # Change the font size of the tick labels on the color bar
82
+ cbar.ax.tick_params(labelsize=20) # Adjust the font size of tick labels
83
+
84
+ plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
85
+ plt.yticks(fontsize=20, rotation=0)
86
+ plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
87
+ plt.xlabel('Residue Index', fontsize=30)
88
+ plt.tight_layout()
89
+ # save the figure before displaying it (plt.show() can clear the figure in some backends)
90
+ plt.savefig(save_path, format='png', dpi=300)
91
+
92
+ plt.show()
93
+
94
+ # plotting heatmap 1
95
+ def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
96
+ start = mutation_results['start']
97
+ end = mutation_results['end']
98
+ logits = mutation_results['logits']
99
+ logits_for_each_AA = mutation_results['logits_for_each_AA']
100
+ filtered_indices = mutation_results['filtered_indices']
101
+
102
+ # get start and end index
103
+ start_index = start - 1
104
+ end_index = end
105
+
106
+ # prepare data for plotting
107
+ token_indices = torch.arange(logits.size(-1))
108
+ tokens = [tokenizer.decode([idx]) for idx in token_indices]
109
+ filtered_tokens = [tokens[i] for i in filtered_indices]
110
+ all_logits_array = np.vstack(logits_for_each_AA)
111
+ normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
112
+ transposed_logits_array = normalized_logits_array.T
113
+
114
+ # get x tick labels
115
+ x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)
116
+
117
+ # make plot
118
+ set_font()
119
+ fig = plt.figure(figsize=(15, 8))
120
+ plt.rcParams.update({'font.size': 16.5})
121
+ sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
122
+ plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
123
+ plt.ylabel('Amino Acid')
124
+ plt.xlabel('Residue Index')
125
+ plt.yticks(rotation=0)
126
+ plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
127
+ plt.tight_layout()
128
+ plt.savefig(save_path, format='png', dpi = 300)
129
+
130
+ def plot_color_bar():
131
+ """
132
+ Create a Viridis color bar ranging from 0 to 1.
133
+ """
134
+ # Create a gradient from 0 to 1
135
+ gradient = np.linspace(0, 1, 256).reshape(1, -1)
136
+
137
+ # Plot the gradient as a color bar
138
+ fig, ax = plt.subplots(figsize=(12, 3))
139
+ ax.imshow(gradient, aspect="auto", cmap="viridis")
140
+ ax.set_xticks([0, 255])
141
+ ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
142
+ ax.set_yticks([])
143
+ ax.set_title("Original Residue Logits", fontsize=40)
144
+
145
+ # Save the figure before displaying it (plt.show() can clear the figure in some backends)
146
+ plt.tight_layout()
147
+ plt.savefig("viridis_color_bar.png", dpi=300)
148
+ plt.show()
149
+
150
+ def main():
151
+ # Call the function to create and display the color bar
152
+ plot_color_bar()
153
+
154
+ results_dir = "results/final"
155
+ subfolders = os.listdir(results_dir)
156
+ for subfolder in subfolders:
157
+ full_path = f"{results_dir}/{subfolder}"
158
+ if os.path.isdir(full_path):
159
+ with open(f"{full_path}/raw_mutation_results.pkl", "rb") as f:
160
+ mutation_results = pickle.load(f)
161
+ plot_conservation_heatmap(mutation_results,
162
+ fusion_name=subfolder, save_path=f"{full_path}/conservation_heatmap.png")
163
+
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main()
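The step-size rule in `get_x_tick_labels` can be checked in isolation. A standard-library sketch mirroring that logic (the function name here is illustrative, and `numpy.arange` is replaced by `range` for a dependency-free example):

```python
def tick_positions_and_labels(start, end):
    # Mirrors get_x_tick_labels: tick every residue for domains under 10,
    # every 10 up to 100, every 50 up to 500, and every 100 beyond that.
    domain_len = end - start
    if 500 > domain_len > 100:
        step_size = 50
    elif domain_len >= 500:
        step_size = 100
    elif domain_len < 10:
        step_size = 1
    else:
        step_size = 10
    # positions are 0-indexed into the sequence; labels are 1-indexed residues
    positions = list(range(start - 1, end, step_size))
    labels = [str(pos + 1) for pos in positions]
    return positions, labels

positions, labels = tick_positions_and_labels(1, 25)  # 24-residue domain -> step 10
```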
fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/521_logit_bfactor.cif ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin → fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/domain_conservation_fusions_inputfile.csv RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3a0b878f4a6dfdeb5bc3acb5ceb82980c7840c4afe5bc1063c23e20e9f8da623
3
- size 2609617594
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bef22b6753a20845f095872431e1a222ac2be40ffa20bfe36113d80a7b57b2
3
+ size 3119
fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/test_seqs_tftf_kk.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e77c7777c0163df7b4ea740d16c5aff0a02cec03a251130dde0cf9b291d4fa2
3
+ size 4023866
fuson_plm/benchmarking/mutation_prediction/discovery/raw_data/salokas_2020_tableS3.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
3
+ size 19626
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/conservation_heatmap.png ADDED
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e881d4a4c081da84a60991e2ad9598fcf1dba47cbcb92b5cb6d70c29f4217ea
3
+ size 432231
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/predicted_tokens.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3349b52a936dbce7810faf3e17a296cf63701fb87e64d51de2d880575778b1b0
3
+ size 10299
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9762879aacee6bd42440326c3d7b68871df8caacf97127db213c47eec9fa7d1f
3
+ size 315798
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/conservation_heatmap.png ADDED
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/full_results_with_logits.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82cddf491a55bf2254ab86b2a20baf161097b5538a5868c0660c88baf856f672
3
+ size 349203
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/predicted_tokens.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd172fd6717511efd5b4971e0370381f1a29d29b12dbc196b82585ed2ef5065
3
+ size 8363
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/raw_mutation_results.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4973a6fc35ee9647f02f83627181b83e5f861c5918dc40abf4abb872d3e7088b
3
+ size 256739
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/conservation_heatmap.png ADDED
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/full_results_with_logits.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d7841257792c6d83d2cfcd5456bf8761b3d3fc0dd69cd912ecfc05a80ad3dec
3
+ size 545639
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/predicted_tokens.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8d94cfad62e2b62bf60d62cfef02fa2328f5d37c7f1c141575d098896bc2103
3
+ size 13323
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/raw_mutation_results.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0dab0b3ee7ea7f9c3c67fdef16bf81dd48654bb8bfdf12eedd58e1972d908a24
3
+ size 408037
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/conservation_heatmap.png ADDED
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/full_results_with_logits.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ad3fc01c9ce3c67330d8fc79d14a551aa4d7f8c8d596a49e8e8042d6a8fbb1b
+ size 635205
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/predicted_tokens.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22cdcd8cfbd5f70f1067d2104e6f0e3fc5e4e5e3fddd736934ce905b35cd0cea
+ size 15195
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/raw_mutation_results.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d719d23b4c7e6deb79ac9ef3c5c84a1c0c3a2e4ae9926f9a460731bd78f9e0c
+ size 465133
fuson_plm/benchmarking/mutation_prediction/discovery/viridis_color_bar.png ADDED
fuson_plm/benchmarking/mutation_prediction/recovery/abl_mutations.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22679a1f55d49ed4a8d5309c27215a4a204f47b49ce0a7a266a674208e381e85
+ size 307
fuson_plm/benchmarking/mutation_prediction/recovery/alk_mutations.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9339f8cd5ea5369a9b87ca3b420ee2731c4e057d5c546c3c2a09c727154721df
+ size 192
fuson_plm/benchmarking/mutation_prediction/recovery/color_recovered_mutations_public.ipynb ADDED
@@ -0,0 +1,314 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FJd6a9gdZNjG"
+ },
+ "source": [
+ "### Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## Put path to model predictions you'd like to use for benchmarking\n",
+ "path_to_bcr_abl_preds = \"results/11-21-2024-16:03:59/Supplementary Tables - BCR ABL Mutations.csv\"\n",
+ "path_to_eml4_alk_preds = \"results/11-21-2024-16:03:59/Supplementary Tables - EML4 ALK Mutations.csv\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install torch pandas numpy py3Dmol scikit-learn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "ZEWZVc9lUxjI"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "import pickle\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "import py3Dmol\n",
+ "\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bcr_abl_preds = pd.read_csv(path_to_bcr_abl_preds)\n",
+ "eml4_alk_preds = pd.read_csv(path_to_eml4_alk_preds)\n",
+ "bcr_abl_seq = ''.join(bcr_abl_preds['Original Residue'].tolist())\n",
+ "eml4_alk_seq = ''.join(eml4_alk_preds['Original Residue'].tolist())\n",
+ "\n",
+ "print(\"BCR::ABL seq: \", bcr_abl_seq)\n",
+ "print(\"EML4::ALK seq: \", eml4_alk_seq)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bcr_abl_preds = bcr_abl_preds[bcr_abl_preds['Literature Mutation'].notna()].reset_index(drop=True)\n",
+ "eml4_alk_preds = eml4_alk_preds[eml4_alk_preds['Literature Mutation'].notna()].reset_index(drop=True)\n",
+ "\n",
+ "bcr_abl_preds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculate hits\n",
+ "print(\"BCR::ABL\", len(bcr_abl_preds), sum(bcr_abl_preds['Hit']))\n",
+ "print(\"EML4::ALK\", len(eml4_alk_preds), sum(eml4_alk_preds['Hit']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.getcwd()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Color EML4 ALK structure\n",
+ "path_to_eml4_alk_structure = np.nan # not publicly available\n",
+ "eml4_alk_seq = np.nan # not publicly available\n",
+ "alk_kinase_domain_seq = np.nan # not publicly available\n",
+ "alk_kinase_domain_resis = [eml4_alk_seq.index(alk_kinase_domain_seq)+1, eml4_alk_seq.index(alk_kinase_domain_seq)+len(alk_kinase_domain_seq)]\n",
+ "print(alk_kinase_domain_resis)\n",
+ "print(len(eml4_alk_seq))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "l = list(range(567,843+1))\n",
+ "l = [str(x) for x in l]\n",
+ "print('+'.join(l))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "eml4_alk_mutation_resis = eml4_alk_preds['Position'].tolist()\n",
+ "eml4_alk_hit_resis = eml4_alk_preds.loc[\n",
+ "    eml4_alk_preds['Hit']\n",
+ "]['Position'].tolist()\n",
+ "kinase_domain_coloring = {\n",
+ "    'kinase': [alk_kinase_domain_resis[0], alk_kinase_domain_resis[1], '#6ea4da']\n",
+ "}\n",
+ "\n",
+ "missed_mut_resis = [x for x in eml4_alk_mutation_resis if x not in eml4_alk_hit_resis]\n",
+ "print('missed', '+'.join([str(x) for x in missed_mut_resis]))\n",
+ "\n",
+ "print('hit', '+'.join([str(x) for x in eml4_alk_hit_resis]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PYMOL code for EML4 ALK\n",
+ "# Set global cartoon transparency\n",
+ "# 0.5 for close up\n",
+ "color gray60\n",
+ "set cartoon_transparency, 0.6\n",
+ "\n",
+ "# Define a custom color\n",
+ "set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
+ "\n",
+ "# Select and color residues 567-843\n",
+ "sele alk, resi 567-843\n",
+ "color custom_blue, alk\n",
+ "\n",
+ "# Select missed mutation residues, color them orange, and make them fully opaque\n",
+ "sele missed_mut_resis, resi 603+707\n",
+ "color orange, missed_mut_resis\n",
+ "set cartoon_transparency, 0, missed_mut_resis\n",
+ "\n",
+ "# Select hit mutation residues, color them magenta, and make them fully opaque\n",
+ "sele hit_mut_resis, resi 607+622+625+631+647+649+653+654+657+661+696+720\n",
+ "color magenta, hit_mut_resis\n",
+ "set cartoon_transparency, 0, hit_mut_resis\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def color_mutation_recovery(pdb_file, mutation_resis, hit_resis, other_domain_coloring=None):\n",
+ "    # Create viewer\n",
+ "    viewer = py3Dmol.view()\n",
+ "\n",
+ "    # Load CIF file\n",
+ "    with open(pdb_file, 'r') as f:\n",
+ "        pdb_file = f.read()\n",
+ "\n",
+ "    # Add structure\n",
+ "    viewer.addModel(pdb_file, 'pdb')\n",
+ "\n",
+ "    viewer.setStyle({})\n",
+ "    viewer.setStyle({'cartoon': {'color': 'lightgrey','transparency': 0.7}})\n",
+ "\n",
+ "    # Apply colors based on hit/miss status of each mutated residue\n",
+ "    for i, value in enumerate(mutation_resis):\n",
+ "        style = {'cartoon': {'transparency': 1}}\n",
+ "        if value in hit_resis:\n",
+ "            style['cartoon']['color'] = 'yellow'\n",
+ "        else:\n",
+ "            style['cartoon']['color'] = 'red'\n",
+ "        # Apply combined style directly, avoiding any residual layering\n",
+ "        viewer.setStyle({'resi': value}, style)\n",
+ "\n",
+ "    # If you want to color specific domains by default, do it here\n",
+ "    if other_domain_coloring is not None:\n",
+ "        for name, (start, end, color) in other_domain_coloring.items():\n",
+ "            resis = list(range(start, end+1))\n",
+ "            resis = [x for x in resis if x not in mutation_resis]\n",
+ "            viewer.setStyle({'resi': resis}, {'cartoon': {'color': color,'transparency': 0.7}})\n",
+ "\n",
+ "    # Show viewer\n",
+ "    viewer.zoomTo()\n",
+ "    return viewer.show()\n",
+ "\n",
+ "color_mutation_recovery(path_to_eml4_alk_structure, eml4_alk_mutation_resis, eml4_alk_hit_resis,\n",
+ "                        other_domain_coloring=kinase_domain_coloring)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Color BCR::ABL structure\n",
+ "path_to_bcr_abl_structure = np.nan # not publicly available\n",
+ "bcr_abl_seq = np.nan # not publicly available\n",
+ "abl_kinase_domain_seq = np.nan # not publicly available\n",
+ "abl_kinase_domain_resis = [bcr_abl_seq.index(abl_kinase_domain_seq)+1, bcr_abl_seq.index(abl_kinase_domain_seq)+len(abl_kinase_domain_seq)]\n",
+ "print(abl_kinase_domain_resis)\n",
+ "print(abl_kinase_domain_seq)\n",
+ "print(len(bcr_abl_seq))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bcr_abl_mutation_resis = bcr_abl_preds['Position'].tolist()\n",
+ "bcr_abl_hit_resis = bcr_abl_preds.loc[\n",
+ "    bcr_abl_preds['Hit']\n",
+ "]['Position'].tolist()\n",
+ "kinase_domain_coloring = {\n",
+ "    'kinase': [abl_kinase_domain_resis[0], abl_kinase_domain_resis[1], '#6ea4da']\n",
+ "}\n",
+ "\n",
+ "bcr_abl_missed_mut_resis = [x for x in bcr_abl_mutation_resis if x not in bcr_abl_hit_resis]\n",
+ "print('missed residues', len(bcr_abl_missed_mut_resis), '+'.join([str(x) for x in bcr_abl_missed_mut_resis]))\n",
+ "print('hit residues', len(bcr_abl_hit_resis), '+'.join([str(x) for x in bcr_abl_hit_resis]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PYMOL code for BCR ABL\n",
+ "# Set global cartoon transparency\n",
+ "# 0 for zoomed out, 0.5 for close up\n",
+ "color gray60\n",
+ "set cartoon_transparency, 0.6\n",
+ "\n",
+ "# Define a custom color\n",
+ "set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
+ "\n",
+ "# Select and color residues 1085-1336\n",
+ "sele abl, resi 1085-1336\n",
+ "color custom_blue, abl\n",
+ "\n",
+ "# Select missed mutation residues, color them orange, and make them fully opaque\n",
+ "sele missed_mut_resis, resi 1087+1090+1093+1095+1116+1125+1128+1135+1140+1158+1192+1194+1218+1249+1273\n",
+ "color orange, missed_mut_resis\n",
+ "set cartoon_transparency, 0, missed_mut_resis\n",
+ "\n",
+ "# Select hit mutation residues, color them magenta, and make them fully opaque\n",
+ "sele hit_mut_resis, resi 1091+1096+1098+1132+1142+1154+1160+1198+1202+1222+1227+1230+1239\n",
+ "color magenta, hit_mut_resis\n",
+ "set cartoon_transparency, 0, hit_mut_resis"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [
+ "FJd6a9gdZNjG",
+ "zORkLJztZWp9",
+ "w25hagtZaV65",
+ "IbyqxlvAFUAK",
+ "0n5PSprbhLk7"
+ ],
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
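The notebook cells above repeatedly build PyMOL `resi` selections by joining residue numbers with `+` (e.g. `'+'.join([str(x) for x in missed_mut_resis])`). A minimal sketch of that pattern as a reusable helper; the function name `pymol_resi_selection` is ours, for illustration:

```python
def pymol_resi_selection(resis):
    """Join residue numbers into a PyMOL `resi` selection string,
    e.g. [603, 707] -> '603+707', usable as `sele missed, resi 603+707`."""
    return '+'.join(str(x) for x in resis)

# Split literature-mutation positions into hit vs. missed, as the notebook does
mutation_resis = [603, 607, 707, 720]
hit_resis = [607, 720]
missed = [x for x in mutation_resis if x not in hit_resis]

print(pymol_resi_selection(missed))      # 603+707
print(pymol_resi_selection(hit_resis))   # 607+720
```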
fuson_plm/benchmarking/mutation_prediction/recovery/config.py ADDED
@@ -0,0 +1,3 @@
+ FUSON_PLM_CKPT = "FusOn-pLM" # Dictionary: key = run name, values = epochs, or string "FusOn-pLM"
+
+ CUDA_VISIBLE_DEVICES = "0"
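As the comment in `config.py` notes, `FUSON_PLM_CKPT` is either the string `"FusOn-pLM"` (use the released model) or a `{run_name: epoch}` dictionary pointing at a training checkpoint. A condensed sketch of how `recover_public.py` resolves this setting into a checkpoint path and a `model_str` label; the helper name `resolve_ckpt` is ours:

```python
def resolve_ckpt(fuson_plm_ckpt):
    """Resolve the FUSON_PLM_CKPT config value to (checkpoint_path, model_str)."""
    if fuson_plm_ckpt == "FusOn-pLM":
        # Released model lives at the repository root
        return "../../../..", "fuson_plm/best"
    # Otherwise a {run_name: epoch} dictionary selects a training checkpoint
    model_name = list(fuson_plm_ckpt.keys())[0]
    epoch = list(fuson_plm_ckpt.values())[0]
    path = f"../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}"
    return path, f"{model_name}/epoch_{epoch}"

print(resolve_ckpt("FusOn-pLM"))    # ('../../../..', 'fuson_plm/best')
print(resolve_ckpt({"my_run": 5}))  # ('../../training/checkpoints/my_run/checkpoint_epoch_5', 'my_run/epoch_5')
```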
fuson_plm/benchmarking/mutation_prediction/recovery/recover_public.py ADDED
@@ -0,0 +1,330 @@
+ #### Recover mutations from literature. A benchmark
+ import fuson_plm.benchmarking.mutation_prediction.recovery.config as config
+ import os
+ os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
+
+ import pandas as pd
+ import numpy as np
+ import transformers
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ import logging
+ import torch
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import argparse
+ import torch.nn.functional as F
+
+ from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
+ from fuson_plm.benchmarking.embed import load_fuson_model
+
+ def check_env_variables():
+     log_update("\nChecking on environment variables...")
+     log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
+     log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
+     for i in range(torch.cuda.device_count()):
+         log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
+
+ def get_top_k_aa_mutations(all_probabilities, sequence, i, top_k_mutations, k=10):
+     """
+     Should only return top AA mutations
+     """
+     all_probs = pd.DataFrame.from_dict(all_probabilities, orient='index').reset_index()
+     all_probs = all_probs.sort_values(by=0, ascending=False).reset_index(drop=True)
+     top_k_mutation = all_probs['index'].tolist()[0:k]
+     top_k_mutation = ",".join(top_k_mutation)
+     top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)
+
+     return top_k_mutations
+
+ def get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_k_mutations, k=3):
+     top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
+     top_k_mutation = []
+     for token in top_k_tokens:
+         replaced_text = tokenizer.decode([token])
+         top_k_mutation.append(replaced_text)
+
+     top_k_mutation = ",".join(top_k_mutation)
+     top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)
+
+ def predict_positionwise_mutations(model, tokenizer, device, sequence):
+     log_update("\t\tPredicting position-wise mutations...")
+     top_10_mutations = {}
+     decoded_full_sequence = ''
+     mut_count = 0
+
+     # Mask and unmask sequentially
+     for i in range(len(sequence)):
+         log_update(f"\t\t\t- pos {i+1}/{len(sequence)}")
+         all_probabilities = {} # stored probabilities of each AA at this position
+
+         # Mask JUST the current position
+         masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
+         inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Forward pass
+         with torch.no_grad():
+             logits = model(**inputs).logits
+
+         # Find logits at masked positions (should just be 1!)
+         mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+         mask_token_logits = logits[0, mask_token_index, :]
+         mask_token_probs = F.softmax(mask_token_logits, dim=-1)
+
+         # Collect probabilities for natural AAs (token IDs 4-23 inclusive)
+         for token_idx in range(4, 23 + 1):
+             token = mask_token_probs[0, token_idx]
+             replaced_text = tokenizer.decode([token_idx])
+             all_probabilities[replaced_text] = token.item()
+
+         # Isolate top n mutations
+         #get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_10_mutations, k=10)
+         get_top_k_aa_mutations(all_probabilities, sequence, i, top_10_mutations, k=10)
+
+         # Building whole decoded sequence with top 1 token
+         top_1_tokens = torch.topk(mask_token_logits, 1, dim=1).indices[0].item()
+         new_residue = tokenizer.decode([top_1_tokens])
+         decoded_full_sequence += new_residue
+
+         # Check how many mutations in total
+         if sequence[i] != new_residue:
+             mut_count += 1
+
+     # Convert results into DataFrame
+     original_residues = []
+     top10_mutations = []
+     positions = []
+     all_logits = []
+
+     for (original_residue, position), (top10, probs) in top_10_mutations.items():
+         original_residues.append(original_residue)
+         top10_mutations.append(top10)
+         positions.append(position+1) # originally this line was "position" but it should be position + 1
+         all_logits.append(probs)
+
+     df = pd.DataFrame({
+         'Original Residue': original_residues,
+         'Position': positions,
+         'Top 10 Mutations': top10_mutations,
+         'All Probabilities': all_logits,
+     })
+     df['Top Mutation'] = df['Top 10 Mutations'].apply(lambda x: x.split(',')[0])
+     df['Top 3 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:3]))
+     df['Top 4 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:4]))
+     df['Top 5 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:5]))
+
+     return df, decoded_full_sequence, mut_count
+
+ def evaluate_literature_mut_performance(predicted_mutations_df, literature_mutations_df, decoded_full_sequence, mut_count, sequence="", focus_region_start=0, focus_region_end=0, offset=0):
+     """
+     Given a dataframe of predicted mutations and literature mutations, see how well the predicted mutations did
+     """
+     log_update("\t\tComparing predicted mutations to literature-provided mutations")
+     return_df = predicted_mutations_df.copy(deep=True)
+     return_df['Literature Mutation'] = [np.nan]*len(return_df)
+     return_df['Top 1 Hit'] = [np.nan]*len(return_df)
+     return_df['Top 3 Hit'] = [np.nan]*len(return_df)
+     return_df['Top 4 Hit'] = [np.nan]*len(return_df)
+     return_df['Top 5 Hit'] = [np.nan]*len(return_df)
+     return_df['Top 10 Hit'] = [np.nan]*len(return_df)
+
+     log_update(f"\tFormula: new position = {focus_region_start} + lit_position - {offset}")
+     # Iterate through the literature mutations rows
+     for i, row in literature_mutations_df.iterrows():
+         lit_position = row['Position']
+         lit_mutations = row['Mutation']
+         original_residue = row['Original Residue']
+         seq_position = focus_region_start + (lit_position - offset) # find position of the sequence
+
+         matching_row = return_df[return_df['Position'] == seq_position]
+         matching_row_index = matching_row.index
+         matching_residue = matching_row.iloc[0]['Original Residue']
+         match = original_residue==matching_residue
+         log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position}, Full sequence residue: {matching_residue}\n\t\tMatch: {match}")
+
+         # Iterate through the matching rows. We are at the right spot if we have the right original residue.
+         if match:
+             top_mutation = matching_row.iloc[0]['Top Mutation'] # get top mutation
+             top_mutation = top_mutation.split(',')
+             print(top_mutation)
+             return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations # get desired mutation
+             # If we got any of the mutations reported in the literature, hit!
+             if any(letter in lit_mutations for letter in top_mutation):
+                 return_df.loc[matching_row_index, 'Top 1 Hit'] = True
+             else:
+                 return_df.loc[matching_row_index, 'Top 1 Hit'] = False
+
+             for k in [3,4,5,10]:
+                 top_k_mutations = matching_row.iloc[0][f'Top {k} Mutations'] # get top k mutations
+                 top_k_mutations = top_k_mutations.split(",")
+                 print(top_k_mutations)
+                 return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations # get desired mutation
+                 # If we got any of the mutations reported in the literature, hit!
+                 if any(letter in lit_mutations for letter in top_k_mutations):
+                     return_df.loc[matching_row_index, f'Top {k} Hit'] = True
+                 else:
+                     return_df.loc[matching_row_index, f'Top {k} Hit'] = False
+
+     return return_df, (decoded_full_sequence, mut_count, (mut_count/len(sequence)) * 100)
+
+ def evaluate_eml4_alk(model, tokenizer, device, model_str):
+     alk_muts = pd.read_csv("alk_mutations.csv")
+     decoded_full_sequence, mut_count = None, None
+
+     EML4_ALK_SEQ = np.nan ## not publicly available
+     cons_domain_alk = np.nan # not publicly available
+     focus_region_start = EML4_ALK_SEQ.find(cons_domain_alk)
+
+     if os.path.isfile(f"eml4_alk_mutations/{model_str}/mutated_df.csv"):
+         log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from eml4_alk_mutations/{model_str}/mutated_df.csv")
+         mutated_df = pd.read_csv(f"eml4_alk_mutations/{model_str}/mutated_df.csv")
+         mutated_summary = pd.read_csv(f"eml4_alk_mutations/{model_str}/mutated_summary.csv")
+         decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
+         mut_count = mutated_summary['mut_count'][0]
+     else:
+         mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, EML4_ALK_SEQ)
+         mutated_summary = pd.DataFrame(data={'decoded_full_sequence':[decoded_full_sequence],'mut_count':[mut_count]})
+         mutated_df.to_csv(f"eml4_alk_mutations/{model_str}/mutated_df.csv",index=False)
+         mutated_summary.to_csv(f"eml4_alk_mutations/{model_str}/mutated_summary.csv",index=False)
+
+     lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(mutated_df, alk_muts, decoded_full_sequence, mut_count,
+                                                             sequence=EML4_ALK_SEQ,
+                                                             focus_region_start=focus_region_start,
+                                                             focus_region_end = focus_region_start + len(cons_domain_alk),
+                                                             offset=1115 # original: 1116
+                                                             )
+
+     return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
+
+ def evaluate_bcr_abl(model, tokenizer, device, model_str):
+     abl_muts = pd.read_csv("abl_mutations.csv")
+     decoded_full_sequence, mut_count = None, None
+
+     BCR_ABL_SEQ = np.nan ## not publicly available
+     cons_domain_abl = np.nan ## not publicly available
+     focus_region_start = BCR_ABL_SEQ.find(cons_domain_abl)
+
+     if os.path.isfile(f"bcr_abl_mutations/{model_str}/mutated_df.csv"):
+         log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from bcr_abl_mutations/{model_str}/mutated_df.csv")
+         mutated_df = pd.read_csv(f"bcr_abl_mutations/{model_str}/mutated_df.csv")
+         mutated_summary = pd.read_csv(f"bcr_abl_mutations/{model_str}/mutated_summary.csv")
+         decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
+         mut_count = mutated_summary['mut_count'][0]
+     else:
+         mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, BCR_ABL_SEQ)
+         mutated_summary = pd.DataFrame(data={'decoded_full_sequence':[decoded_full_sequence],'mut_count':[mut_count]})
+         mutated_df.to_csv(f"bcr_abl_mutations/{model_str}/mutated_df.csv",index=False)
+         mutated_summary.to_csv(f"bcr_abl_mutations/{model_str}/mutated_summary.csv",index=False)
+
+     lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(mutated_df, abl_muts, decoded_full_sequence, mut_count,
+                                                             sequence=BCR_ABL_SEQ,
+                                                             focus_region_start=focus_region_start,
+                                                             focus_region_end = focus_region_start + len(cons_domain_abl),
+                                                             offset=241 # original: 242
+                                                             )
+
+     return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
+
+ def summarize_individual_performance(performance_df, path_to_lit_df):
+     """
+     performance_df = dataframe with stats on performance
+     path_to_lit_df = original dataframe
+     """
+     # Load original df
+     lit_muts = pd.read_csv(path_to_lit_df)
+
+     # Mutated Sequence,Original Residue,Position,Top 3 Mutations,Literature Mutation,Hit,All Probabilities
+     mut_rows = performance_df.loc[performance_df['Literature Mutation'].notna()].reset_index(drop=True)
+     mut_rows = mut_rows[['Original Residue','Position','Literature Mutation',
+                          'Top Mutation','Top 1 Hit',
+                          'Top 3 Mutations','Top 3 Hit',
+                          'Top 4 Mutations','Top 4 Hit',
+                          'Top 5 Mutations','Top 5 Hit',
+                          'Top 10 Mutations','Top 10 Hit'
+                          ]]
+
+     mut_rows_str = mut_rows.to_string(index=False)
+     mut_rows_str = "\t\t" + mut_rows_str.replace("\n","\n\t\t")
+     log_update(f"\tPerformance on all mutated positions shown below:\n{mut_rows_str}")
+
+     # Summarize: total hits, percentage of hits
+     total_original_muts = len(lit_muts)
+     for k in [1,3,4,5,10]:
+         total_hits = len(mut_rows.loc[mut_rows[f'Top {k} Hit']==True])
+         total_misses = len(mut_rows.loc[mut_rows[f'Top {k} Hit']==False])
+         total_potential_muts = total_hits+total_misses
+         hit_pcnt = round(100*total_hits/total_potential_muts, 2)
+         miss_pcnt = round(100*total_misses/total_potential_muts, 2)
+
+         log_update(f"\tTotal positions tested / total positions mutated in literature: {total_potential_muts}/{total_original_muts}")
+         log_update(f"\n\t\tTop {k} hit performance:\n\t\t\tHits:{total_hits} ({hit_pcnt}%)\n\t\t\tMisses:{total_misses} ({miss_pcnt}%)")
+
+ def main():
+     os.makedirs('results',exist_ok=True)
+     output_dir = f'results/{get_local_time()}'
+     os.makedirs(output_dir,exist_ok=True)
+     os.makedirs("bcr_abl_mutations",exist_ok=True)
+     os.makedirs("eml4_alk_mutations",exist_ok=True)
+     with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
+         print_configpy(config)
+
+         # Make sure environment variables are set correctly
+         check_env_variables()
+
+         # Get device
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         log_update(f"Using device: {device}")
+
+         # Load fuson
+         fuson_ckpt_path = config.FUSON_PLM_CKPT
+         if fuson_ckpt_path=="FusOn-pLM":
+             fuson_ckpt_path="../../../.."
+             model_name = "fuson_plm"
+             model_epoch = "best"
+             model_str = f"fuson_plm/best"
+         else:
+             model_name = list(fuson_ckpt_path.keys())[0]
+             epoch = list(fuson_ckpt_path.values())[0]
+             fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
+             model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
+             model_epoch = model_epoch.split('checkpoint_')[-1]
+             model_str = f"{model_name}/{model_epoch}"
+
+         log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
+         fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
+         fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
+         fuson_model.to(device)
+         fuson_model.eval()
+
+         # Evaluate BCR::ABL performance with FusOn
+         os.makedirs(f"bcr_abl_mutations/{model_name}",exist_ok=True)
+         os.makedirs(f"bcr_abl_mutations/{model_name}/{model_epoch}",exist_ok=True)
+         log_update("\tEvaluating performance on BCR::ABL mutation prediction with FusOn")
+         abl_lit_performance_fuson, (mut_seq, mut_count, mut_pcnt) = evaluate_bcr_abl(fuson_model, fuson_tokenizer, device, model_str)
+         abl_lit_performance_fuson.to_csv(f'{output_dir}/BCR_ABL_mutation_recovery_fuson.csv', index = False)
+         with open(f'{output_dir}/BCR_ABL_mutation_recovery_fuson_summary.txt', 'w') as f:
+             f.write(mut_seq)
+             f.write(f'number of mutations: {mut_count}')
+             f.write(f'percentage of seq mutated: {mut_pcnt}')
+
+         # Evaluate EML4::ALK performance with FusOn
+         os.makedirs(f"eml4_alk_mutations/{model_name}",exist_ok=True)
+         os.makedirs(f"eml4_alk_mutations/{model_name}/{model_epoch}",exist_ok=True)
+         log_update("\tEvaluating performance on EML4::ALK mutation prediction with FusOn")
+         alk_lit_performance_fuson, (mut_seq, mut_count, mut_pcnt) = evaluate_eml4_alk(fuson_model, fuson_tokenizer, device, model_str)
+         alk_lit_performance_fuson.to_csv(f'{output_dir}/EML4_ALK_mutation_recovery_fuson.csv', index = False)
+         with open(f'{output_dir}/EML4_ALK_mutation_recovery_fuson_summary.txt', 'w') as f:
+             f.write(mut_seq)
+             f.write(f'number of mutations: {mut_count}')
+             f.write(f'percentage of seq mutated: {mut_pcnt}')
+
+         ### Summarize
+         log_update("\nSummarizing FusOn-pLM performance on BCR::ABL")
+         summarize_individual_performance(abl_lit_performance_fuson, "abl_mutations.csv")
+         log_update("\nSummarizing FusOn-pLM performance on EML4::ALK")
+         summarize_individual_performance(alk_lit_performance_fuson, "alk_mutations.csv")
+
+ if __name__ == "__main__":
+     main()
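The ranking step in `get_top_k_aa_mutations` above (sort the per-position amino-acid probability dictionary, keep the top k, join with commas) can be exercised in isolation. A minimal, model-free sketch of that logic; the function name `top_k_from_probs` and the example probabilities are ours:

```python
import pandas as pd

def top_k_from_probs(all_probabilities, k=3):
    """Rank amino acids by predicted probability and return the top k as a
    comma-joined string, matching the format produced by get_top_k_aa_mutations."""
    all_probs = pd.DataFrame.from_dict(all_probabilities, orient='index').reset_index()
    all_probs = all_probs.sort_values(by=0, ascending=False).reset_index(drop=True)
    return ",".join(all_probs['index'].tolist()[:k])

# Hypothetical softmax probabilities for one masked position
probs = {'A': 0.05, 'L': 0.60, 'V': 0.25, 'G': 0.10}
print(top_k_from_probs(probs, k=3))  # L,V,G
```

Downstream, the top-1 entry of this string becomes the `Top Mutation` column, and the `Top k Hit` columns check whether any of the top-k residues appear among the literature-reported mutations at that position.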
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:779bac32bf18f7efdf843120274bdcd0ee2dd1a940655b7f1af1bc0d640bda1a
+ size 18267
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2fdc821e3e303aae226567cea331437eed8af96db43ebbd7487b7a48767de51e
+ size 9375
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - EML4 ALK Mutations.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbc2b69ac1e403b060cda54b753c0e8cd9c0bbac1c64cdf38128aaa23081e12d
+ size 709
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - BCR ABL Mutations.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a100f40575539ae5009e0ee87aac73b4b7701ee108cacd417585a8e02155e5c1
+ size 1380
fuson_plm/benchmarking/puncta/README.md CHANGED
@@ -1,6 +1,6 @@
  ## Puncta Prediction Benchmark
 
- This folder contains all the data and code needed to perform the **puncta prediction benchmark** (Figure 3).
+ This folder contains all the data and code needed to train FusOn-pLM-Puncta models and perform the **puncta prediction benchmark** (Figure 3).
 
  ### From raw data to train/test splits
  To train the puncta predictors, we processed raw data from FOdb [(Tripathi et al. 2023)](https://doi.org/10.1038/s41467-023-41655-2) Supplementary dataset 4 (`fuson_plm/data/raw_data/FOdb_puncta.csv`) and Supplementary dataset 5 (`fuson_plm/data/raw_data/FODb_SD5.csv`) using the file `clean.py` in the `puncta` directory.