mutation prediction discovery and recovery
Browse files- fuson_plm/benchmarking/caid/README.md +2 -2
- fuson_plm/benchmarking/idr_prediction/README.md +2 -2
- fuson_plm/benchmarking/mutation_prediction/README.md +114 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/clean.py +71 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/color_discovered_mutations.ipynb +418 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/config.py +16 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/discover.py +346 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/make_color_bar.py +25 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/plot.py +167 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/521_logit_bfactor.cif +0 -0
- pytorch_model.bin → fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/domain_conservation_fusions_inputfile.csv +2 -2
- fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/test_seqs_tftf_kk.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/raw_data/salokas_2020_tableS3.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/conservation_heatmap.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/predicted_tokens.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/conservation_heatmap.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/full_results_with_logits.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/predicted_tokens.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/raw_mutation_results.pkl +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/conservation_heatmap.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/full_results_with_logits.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/predicted_tokens.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/raw_mutation_results.pkl +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/conservation_heatmap.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/full_results_with_logits.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/predicted_tokens.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/raw_mutation_results.pkl +3 -0
- fuson_plm/benchmarking/mutation_prediction/discovery/viridis_color_bar.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/abl_mutations.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/alk_mutations.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/color_recovered_mutations_public.ipynb +314 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/config.py +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/recover_public.py +330 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - EML4 ALK Mutations.csv +3 -0
- fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - BCR ABL Mutations.csv +3 -0
- fuson_plm/benchmarking/puncta/README.md +1 -1
fuson_plm/benchmarking/caid/README.md
CHANGED
@@ -158,8 +158,8 @@ Here we describe what each script does and which files each script creates.
|
|
158 |
a. Drops rows where no structure was successfully downloaded
|
159 |
b. Drops rows where the FO sequence from FusionPDB does not match the FO sequence from its own AlphaFold2 structure file
|
160 |
c. ⭐️Saves **two final, cleaned databases**⭐️:
|
161 |
-
a. ⭐️ **`FusionPDB_level2-3_cleaned_FusionGID_info.csv`**: includes ful IDs and structural information for the Hgene and Tgene of each FO. Columns="FusionGID","FusionGene","Hgene","Tgene","URL","HGID","TGID","HGUniProtAcc","TGUniProtAcc","HGUniProtAcc_Source","TGUniProtAcc_Source","HG_pLDDT","HG_AA_pLDDTs","HG_Seq","TG_pLDDT","TG_AA_pLDDTs","TG_Seq".
|
162 |
-
b. ⭐️ **`FusionPDB_level2-3_cleaned_structure_info.csv`**: includes full structural information for each FO. Columns= "FusionGID","FusionGene","Fusion_Seq","Fusion_Length","Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","Level","Fusion_Structure_Link","Fusion_Structure_Type","Fusion_pLDDT","Fusion_AA_pLDDTs","Fusion_Seq_Source"
|
163 |
|
164 |
|
165 |
### Training
|
|
|
158 |
a. Drops rows where no structure was successfully downloaded
|
159 |
b. Drops rows where the FO sequence from FusionPDB does not match the FO sequence from its own AlphaFold2 structure file
|
160 |
c. ⭐️Saves **two final, cleaned databases**⭐️:
|
161 |
+
a. ⭐️ **`FusionPDB_level2-3_cleaned_FusionGID_info.csv`**: includes ful IDs and structural information for the Hgene and Tgene of each FO. Columns = "FusionGID", "FusionGene", "Hgene", "Tgene", "URL", "HGID", "TGID", "HGUniProtAcc", "TGUniProtAcc", "HGUniProtAcc_Source", "TGUniProtAcc_Source", "HG_pLDDT", "HG_AA_pLDDTs", "HG_Seq", "TG_pLDDT", "TG_AA_pLDDTs", "TG_Seq".
|
162 |
+
b. ⭐️ **`FusionPDB_level2-3_cleaned_structure_info.csv`**: includes full structural information for each FO. Columns = "FusionGID", "FusionGene", "Fusion_Seq", "Fusion_Length", "Hgene", "Hchr", "Hbp", "Hstrand", "Tgene", "Tchr", "Tbp", "Tstrand", "Level", "Fusion_Structure_Link", "Fusion_Structure_Type", "Fusion_pLDDT", "Fusion_AA_pLDDTs", "Fusion_Seq_Source"
|
163 |
|
164 |
|
165 |
### Training
|
fuson_plm/benchmarking/idr_prediction/README.md
CHANGED
@@ -14,7 +14,7 @@ python plot.py # if you want to remake r2 plots
|
|
14 |
```
|
15 |
|
16 |
### Downloading raw IDR data
|
17 |
-
IDR properties from [Lotthammer et al. 2024](https://doi.org/10.1038/s41592-023-02159-5) (ALBATROSS model) were used to train FusOn-pLM-
|
18 |
|
19 |
```
|
20 |
benchmarking/
|
@@ -144,7 +144,7 @@ The model is defined in `model.py` and `utils.py`. The `train.py` script trains
|
|
144 |
- All **raw outputs from models** are stored in `idr_prediction/trained_models/<embedding_path>`, where `embedding_path` represents the embeddings used to build the disorder predictor.
|
145 |
- All **embeddings** made for training will be stored in a new folder called `idr_prediction/embeddings/` with subfolders for each model. This allows you to use the same model multiple times without regenerating embeddings.
|
146 |
|
147 |
-
Below is the FusOn-pLM-
|
148 |
|
149 |
The outputs are structured as follows:
|
150 |
|
|
|
14 |
```
|
15 |
|
16 |
### Downloading raw IDR data
|
17 |
+
IDR properties from [Lotthammer et al. 2024](https://doi.org/10.1038/s41592-023-02159-5) (ALBATROSS model) were used to train FusOn-pLM-IDR. Sequences were downloaded from [this link](https://github.com/holehouse-lab/supportingdata/blob/master/2023/ALBATROSS_2023/simulations/data/all_sequences.tgz) and deposited in `raw_data`. All files in `raw_data` are from this direct download.
|
18 |
|
19 |
```
|
20 |
benchmarking/
|
|
|
144 |
- All **raw outputs from models** are stored in `idr_prediction/trained_models/<embedding_path>`, where `embedding_path` represents the embeddings used to build the disorder predictor.
|
145 |
- All **embeddings** made for training will be stored in a new folder called `idr_prediction/embeddings/` with subfolders for each model. This allows you to use the same model multiple times without regenerating embeddings.
|
146 |
|
147 |
+
Below is the FusOn-pLM-IDR raw outputs folder, `trained_models/fuson_plm/best/`, and the results from the paper, `results/final/`...
|
148 |
|
149 |
The outputs are structured as follows:
|
150 |
|
fuson_plm/benchmarking/mutation_prediction/README.md
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Mutation Prediction benchmark
|
2 |
+
|
3 |
+
This folder contains all the data and code needed to perform the **mutation prediction benchmark** (Figure 5).
|
4 |
+
|
5 |
+
## Recovery
|
6 |
+
In Fig. 5C, drug resistance mutations in BCR::ABL and EML4::ALK are recovered by FusOn-pLM. Since the full sequences are not publicly available, certain sections of the code and results files have been removed. All results for drug resistance mutation positions in each sequence are included in `mutation_prediction/recovery`.
|
7 |
+
|
8 |
+
```
|
9 |
+
benchmarking/
|
10 |
+
└── mutation_prediction/
|
11 |
+
└── recovery/
|
12 |
+
└── results/final_public/
|
13 |
+
├── BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv
|
14 |
+
├── Supplementary Tables - BCR ABL Mutations.csv
|
15 |
+
├── EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv
|
16 |
+
├── Supplementary Tables - EML4 ALK Mutations.csv
|
17 |
+
├── abl_mutations.csv
|
18 |
+
├── alk_mutations.csv
|
19 |
+
├── color_recovered_mutations_public.ipynb
|
20 |
+
├── recover_public.py
|
21 |
+
├── config.py
|
22 |
+
```
|
23 |
+
In the 📁 **`results`** directory:
|
24 |
+
- **`abl_mutations.csv`**: raw data from the literature with BCR::ABL mutations [(O'Hare et al. 2007)](https://doi.org/10.1182/blood-2007-03-066936)
|
25 |
+
- **`alk_mutations.csv`**: raw data from the literature with EML4::ALK mutations [(Elshatlawy et al. 2023)](https://doi.org/10.1002/1878-0261.13446)
|
26 |
+
- **`color_recovered_mutations_public.csv`**: notebook file used to write PyMOL code for visualizations in Fig. 5C
|
27 |
+
- **`recover_public.py`**: python file to run the analysis. Since private sequences are removed, this notebook will not run, but it includes all steps taken.
|
28 |
+
- **`config.py`**: small config file to specify which FusOn-pLM checkpoint is being used.
|
29 |
+
|
30 |
+
In the 📁 **`results/final_public`** directory,
|
31 |
+
- **`BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv`**: raw logits for every possible mutation at all positions in BCR::ABL with known drug resistance mutations
|
32 |
+
- **`Supplementary Tables - BCR ABL Mutations.csv`**: supplementary table showing the calculations going into the hit rate for BCR::ABL
|
33 |
+
- **`EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv`**: raw logits for every possible mutation at all positions in EML4::ALK with known drug resistance mutations
|
34 |
+
- **`Supplementary Tables - EML4 ALK Mutations.csv`**: supplementary table showing the calculations going into the hit rate for EML4::ALK
|
35 |
+
|
36 |
+
## Discovery
|
37 |
+
The `mutation_prediction/discovery/` directory contains all code and files needed to reproduce Fig. 5B and 5D, where mutations are predicted at each position in several fusion oncoproteins.
|
38 |
+
|
39 |
+
### Data download
|
40 |
+
To help select TF and Kinase-containing fusions for investigation (Fig. 5B), Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8) was downloaded as a reference of transcription factors and kinases. In `clean.py`, this data is processed.
|
41 |
+
|
42 |
+
```
|
43 |
+
benchmarking/
|
44 |
+
└── mutation_prediction/
|
45 |
+
└── discovery/
|
46 |
+
└── raw_data/
|
47 |
+
├── salokas_2020_tableS3.csv
|
48 |
+
└── processed_data/
|
49 |
+
├── test_seqs_tftf_kk.csv
|
50 |
+
├── domain_conservation_fusions_inputfile.csv
|
51 |
+
```
|
52 |
+
|
53 |
+
- **`raw_data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8)
|
54 |
+
- **`processed_data/test_seqs_tftf_kk.csv`**: fusion oncoproteins in the FusOn-pLM test set that have either a Transcription Factor (TF) or a kinase as either the head or tail.
|
55 |
+
- **`processed_data/domain_conservation_fusions_inputfile.csv`**: input file for discovery, containing the longest sequences of EWSR1::FLI1, PAX3::FOXO1, TRIM24::RET, and ETV6::NTRK3.
|
56 |
+
|
57 |
+
### Mutation discovery
|
58 |
+
|
59 |
+
Run `discover.py` to perform discovery on the sequences specified in `config.py` (either one or many):
|
60 |
+
|
61 |
+
```
|
62 |
+
# Model settings: where is the model you wish to use for mutation discovery?
|
63 |
+
FUSON_PLM_CKPT = "FusOn-pLM"
|
64 |
+
|
65 |
+
#### Fill in this sectinon if you have one input
|
66 |
+
# Sequence settings: need full sequence of fusion oncoprotein, and the bounds of region of interest
|
67 |
+
FULL_FUSION_SEQUENCE = ""
|
68 |
+
FUSION_NAME = "fusion_example"
|
69 |
+
START_RESIDUE_INDEX = 1
|
70 |
+
END_RESIDUE_INDEX = 100
|
71 |
+
N = 3 # number of mutations to predict per amio acid
|
72 |
+
|
73 |
+
#### Fill in this section if you have multiple input
|
74 |
+
PATH_TO_INPUT_FILE = "processed_data/domain_conservation_fusions_inputfile.csv" # if you don't have an input file and want to do one sequence, set this variable to None
|
75 |
+
|
76 |
+
# GPU Settings: which GPUs should be available to run this discovery?
|
77 |
+
CUDA_VISIBLE_DEVICES = "0"
|
78 |
+
```
|
79 |
+
|
80 |
+
To run, use:
|
81 |
+
```
|
82 |
+
nohup python discover.py > discover.out 2> discover.err &
|
83 |
+
```
|
84 |
+
- All **results** are stored in `idr_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
|
85 |
+
|
86 |
+
Below are the FusOn-pLM paper results in `results/final`:
|
87 |
+
|
88 |
+
```
|
89 |
+
benchmarking/
|
90 |
+
└── mutation_prediction/
|
91 |
+
└── discovery/
|
92 |
+
└── results/final/
|
93 |
+
└── EWSR1::FLI1/
|
94 |
+
├── conservation_heatmap.png
|
95 |
+
├── full_results_with_logits.csv
|
96 |
+
├── predicted_tokens.csv
|
97 |
+
├── raw_mutation_results.pkl
|
98 |
+
└── PAX3::FOXO1/ # same format as results/final/EWSR1::FLI1...
|
99 |
+
└── TRIM24::RET/ # same format as results/final/EWSR1::FLI1...
|
100 |
+
└── ETV6::NTRK3/ # same format as results/final/EWSR1::FLI1...
|
101 |
+
```
|
102 |
+
In each fusion oncoprotein folder are:
|
103 |
+
- **`conservation_heatmap.png`**: the heatmap from Fig. 5B
|
104 |
+
- **`full_results_with_logits.csv`**: predictions for every residue in the sequence. Columns = "Residue", "original_residue", "original_residue_logit", "all_logits", "top_3_mutations"
|
105 |
+
- **`predicted_tokens.csv`**: simplified format, top three tokens per residue and 1/0 conserved/not conserved label. Columns = "Original Residue", "Predicted Residues", "Conserved", "Position"
|
106 |
+
- **`raw_mutation_results.pkl`**: raw logits in dictionary format.
|
107 |
+
|
108 |
+
### Plotting
|
109 |
+
|
110 |
+
Three scripts aid with plotting:
|
111 |
+
|
112 |
+
1. `make_color_bar.py`: can be run to generate `viridis_color_bar.png`, a labeled and scaled up version of the color bar used in the conservation heatmaps and to color an ETV6::NTRK3 structure (Fig. 5D)
|
113 |
+
2. `plot.py`: includes code for heatmap generation. Can be run to regenerate heatmaps.
|
114 |
+
3. `color_discovered_mutations.ipynb`: notebook used to generate a modified version of the ETV6::NTRK3 structure file, which has the logits predicted by FusOn-pLM as the b factor (`processed_data/521_logit_bfactor.cif`, displayed in Fig. 5D, right). Also has PyMOL code for a head/tail visualization with recovered drug resistance mutations (displayed in Fig. 5D, left)
|
fuson_plm/benchmarking/mutation_prediction/discovery/clean.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Clean the Salokas data, find TF and Kinase fusions in the test set
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
|
5 |
+
def get_gene_type(gene, d):
|
6 |
+
if gene in d:
|
7 |
+
if d[gene] == 'kinase':
|
8 |
+
return 'Kinase'
|
9 |
+
if d[gene] == 'tf':
|
10 |
+
return 'TF'
|
11 |
+
else:
|
12 |
+
return 'Other'
|
13 |
+
|
14 |
+
# Load TF and Kinase Fusions
|
15 |
+
def main():
|
16 |
+
os.makedirs("processed_data", exist_ok=True)
|
17 |
+
|
18 |
+
tf_kinase_parts = pd.read_csv("raw_data/salokas_2020_tableS3.csv")
|
19 |
+
print(tf_kinase_parts)
|
20 |
+
ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'],tf_kinase_parts['Kinase or TF']))
|
21 |
+
|
22 |
+
## Categorize everything in fuson_db
|
23 |
+
fuson_db = pd.read_csv("../../../data/fuson_db.csv")
|
24 |
+
print(fuson_db['benchmark'].value_counts())
|
25 |
+
print(fuson_db.loc[fuson_db['benchmark'].notna()])
|
26 |
+
fgenes = fuson_db.loc[fuson_db['benchmark'].notna()]['fusiongenes'].to_list()
|
27 |
+
print(fuson_db.columns)
|
28 |
+
print(fuson_db)
|
29 |
+
|
30 |
+
# This one has each row with one fusiongene name
|
31 |
+
fuson_ht_db = pd.read_csv("../../../data/blast/fuson_ht_db.csv")
|
32 |
+
print(fuson_ht_db.columns)
|
33 |
+
print(fuson_ht_db)
|
34 |
+
fuson_ht_db[['hg','tg']] = fuson_ht_db['fusiongenes'].str.split("::",expand=True)
|
35 |
+
print(fuson_ht_db.loc[fuson_ht_db['hg']=='PAX3'])
|
36 |
+
print(fuson_ht_db)
|
37 |
+
|
38 |
+
fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
|
39 |
+
fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
|
40 |
+
fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type']+'::'+fuson_ht_db['tg_type']
|
41 |
+
fuson_ht_db['type']=['fusion']*len(fuson_ht_db)
|
42 |
+
|
43 |
+
# Keep things in the test set
|
44 |
+
test_set = pd.read_csv("../../../data/splits/test_df.csv")
|
45 |
+
print(test_set.columns, len(test_set))
|
46 |
+
test_seqs = test_set['sequence'].tolist()
|
47 |
+
fuson_ht_db = fuson_ht_db.loc[
|
48 |
+
fuson_ht_db['aa_seq'].isin(test_seqs)
|
49 |
+
].sort_values(by=['fusion_type']).reset_index(drop=True)
|
50 |
+
fuson_ht_db.to_csv("processed_data/test_seqs_tftf_kk.csv", index=False)
|
51 |
+
|
52 |
+
# isolate a few transcription factor fusions of interest and keep the longest sequence of each
|
53 |
+
fusion_genes_of_interest = [
|
54 |
+
"EWSR1::FLI1", "PAX3::FOXO1", "TRIM24::RET", "ETV6::NTRK3"
|
55 |
+
]
|
56 |
+
df_of_interest = fuson_ht_db.loc[
|
57 |
+
fuson_ht_db['fusiongenes'].isin(fusion_genes_of_interest)
|
58 |
+
].sort_values(by=['fusiongenes','length'],ascending=[True,False]).reset_index(drop=True).drop_duplicates(subset='fusiongenes').reset_index(drop=True)
|
59 |
+
#df_of_interest.to_csv("domain_conservation_fusions.csv",index=False)
|
60 |
+
# Make a file for input into
|
61 |
+
discovery_input = df_of_interest[['fusiongenes','length','aa_seq']]
|
62 |
+
discovery_input['start_residue_index'] = [1]*len(discovery_input)
|
63 |
+
discovery_input['n'] = [3]*len(discovery_input)
|
64 |
+
discovery_input = discovery_input.rename(columns={'length':'end_residue_index',
|
65 |
+
'aa_seq': 'full_fusion_sequence',
|
66 |
+
'fusiongenes':'fusion_name'})
|
67 |
+
discovery_input[['fusion_name','full_fusion_sequence','start_residue_index','end_residue_index','n']].to_csv("processed_data/domain_conservation_fusions_inputfile.csv",index=False)
|
68 |
+
print(discovery_input)
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
main()
|
fuson_plm/benchmarking/mutation_prediction/discovery/color_discovered_mutations.ipynb
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "FJd6a9gdZNjG"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"### Imports"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"!pip install torch pandas numpy py3Dmol scikit-learn"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 37,
|
24 |
+
"metadata": {
|
25 |
+
"id": "ZEWZVc9lUxjI"
|
26 |
+
},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"import torch\n",
|
30 |
+
"import torch.nn as nn\n",
|
31 |
+
"\n",
|
32 |
+
"import pickle\n",
|
33 |
+
"import pandas as pd\n",
|
34 |
+
"import numpy as np\n",
|
35 |
+
"\n",
|
36 |
+
"import py3Dmol\n",
|
37 |
+
"\n",
|
38 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import os\n",
|
48 |
+
"os.getcwd()"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "markdown",
|
53 |
+
"metadata": {},
|
54 |
+
"source": [
|
55 |
+
"# Look at all results for ETV6::NTRK3 discovery"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": null,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"import pickle\n",
|
65 |
+
"path_to_pkl = \"../discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl\"\n",
|
66 |
+
"with open(path_to_pkl, \"rb\") as f:\n",
|
67 |
+
" etv6_ntrk3_logits = pickle.load(f)\n",
|
68 |
+
"\n",
|
69 |
+
"print(etv6_ntrk3_logits)"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 40,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"# Define paths and dataframes that we will need \n",
|
79 |
+
"fusion_benchmark_set = pd.read_csv('../../caid/splits/fusion_bench_df.csv')\n",
|
80 |
+
"fusion_structure_data = pd.read_csv('../../caid/processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')\n",
|
81 |
+
"fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])\n",
|
82 |
+
"fusion_structure_folder = \"raw_data/fusionpdb/structures\""
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"cell_type": "code",
|
87 |
+
"execution_count": null,
|
88 |
+
"metadata": {},
|
89 |
+
"outputs": [],
|
90 |
+
"source": [
|
91 |
+
"# merge fusion data with seq ids \n",
|
92 |
+
"fuson_db = pd.read_csv('../../../data/fuson_db.csv')\n",
|
93 |
+
"print(fuson_db.columns)\n",
|
94 |
+
"fuson_db = fuson_db[['aa_seq','seq_id']].rename(columns={'aa_seq':'Fusion_Seq'})\n",
|
95 |
+
"print(f\"Length of fusion structure data before merge on seqid: {len(fusion_structure_data)}\")\n",
|
96 |
+
"fusion_structure_data = pd.merge(\n",
|
97 |
+
" fusion_structure_data,\n",
|
98 |
+
" fuson_db,\n",
|
99 |
+
" on='Fusion_Seq',\n",
|
100 |
+
" how='inner'\n",
|
101 |
+
")\n",
|
102 |
+
"print(f\"Length of fusion structure data after merge on seqid: {len(fusion_structure_data)}\")\n",
|
103 |
+
"fusion_structure_data"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"# merge fusion structure data with top swissprot alignments\n",
|
113 |
+
"swissprot_top_alignments = pd.read_csv(\"../../../data/blast/blast_outputs/swissprot_top_alignments.csv\")\n",
|
114 |
+
"fusion_structure_data = pd.merge(\n",
|
115 |
+
" fusion_structure_data,\n",
|
116 |
+
" swissprot_top_alignments,\n",
|
117 |
+
" on=\"seq_id\",\n",
|
118 |
+
" how=\"left\"\n",
|
119 |
+
")\n",
|
120 |
+
"fusion_structure_data.head()"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": null,
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [],
|
128 |
+
"source": [
|
129 |
+
"selected_row = fusion_structure_data.loc[\n",
|
130 |
+
" (fusion_structure_data['FusionGene'].str.contains('ETV6::NTRK3')) &\n",
|
131 |
+
" (fusion_structure_data['aa_seq_len']==528)\n",
|
132 |
+
"].reset_index(drop=True).iloc[0,:]\n",
|
133 |
+
"selected_row"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": null,
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"seq = selected_row[\"Fusion_Seq\"]\n",
|
143 |
+
"kinase_seq = \"IVLKRELGEGAFGKVFLAECYNLSPTKDKMLVAVKALKDPTLAARKDFQREAELLTNLQHEHIVKFYGVCGDGDPLIMVFEYMKHGDLNKFLRAHGPDAMILVDGQPRQAKGELGLSQMLHIASQIASGMVYLASQHFVHRDLATRNCLVGANLLVKIGDFGMSRDVYSTDYYRLFNPSGNDFCIWCEVGGHTMLPIRWMPPESIMYRKFTTESDVWSFGVILWEIFTYGKQPWFQLSNTEVIECITQGRVLERPRVCPKEVYDVMLGCWQREPQQRLNIKEIYKILHALGKATPIYLDILG\"\n",
|
144 |
+
"print(\"Length of kinase domain: \", len(kinase_seq))\n",
|
145 |
+
"print(\"1-indexed start position of kinase domain\", seq.index(kinase_seq)-1)\n",
|
146 |
+
"print(\"1-indexed end position of kinase domain (inclusive)\", seq.index(kinase_seq)-1+len(kinase_seq)-1)\n"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "code",
|
151 |
+
"execution_count": 45,
|
152 |
+
"metadata": {},
|
153 |
+
"outputs": [],
|
154 |
+
"source": [
|
155 |
+
"kinase_seq = \"IVLKRELGEGAFGKVFLAECYNLSPTKDKMLVAVKALKDPTLAARKDFQREAELLTNLQHEHIVKFYGVCGDGDPLIMVFEYMKHGDLNKFLRAHGPDAMILVDGQPRQAKGELGLSQMLHIASQIASGMVYLASQHFVHRDLATRNCLVGANLLVKIGDFGMSRDVYSTDYYRLFNPSGNDFCIWCEVGGHTMLPIRWMPPESIMYRKFTTESDVWSFGVILWEIFTYGKQPWFQLSNTEVIECITQGRVLERPRVCPKEVYDVMLGCWQREPQQRLNIKEIYKILHALGKATPIYLDILG\""
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": null,
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [],
|
163 |
+
"source": [
|
164 |
+
"!pip install transformers"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 47,
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [],
|
172 |
+
"source": [
|
173 |
+
"from transformers import AutoTokenizer\n",
|
174 |
+
"import torch.nn.functional as F\n",
|
175 |
+
"fuson_ckpt_path = \"../../../..\"\n",
|
176 |
+
"fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"print(len(etv6_ntrk3_logits['originals_logits']))\n",
|
186 |
+
"print(len(etv6_ntrk3_logits['conservation_likelihoods']))\n",
|
187 |
+
"print(len(etv6_ntrk3_logits['logits_for_each_AA']))\n",
|
188 |
+
"\n",
|
189 |
+
"start = etv6_ntrk3_logits['start']\n",
|
190 |
+
"end = etv6_ntrk3_logits['end']\n",
|
191 |
+
"originals_logits = etv6_ntrk3_logits['originals_logits']\n",
|
192 |
+
"conservation_likelihoods = etv6_ntrk3_logits['conservation_likelihoods']\n",
|
193 |
+
"logits = etv6_ntrk3_logits['logits']\n",
|
194 |
+
"logits_for_each_AA = etv6_ntrk3_logits['logits_for_each_AA']\n",
|
195 |
+
"filtered_indices = etv6_ntrk3_logits['filtered_indices']\n",
|
196 |
+
"top_n_mutations = etv6_ntrk3_logits['top_n_mutations']\n",
|
197 |
+
"\n",
|
198 |
+
"token_indices = torch.arange(logits.size(-1))\n",
|
199 |
+
"tokens = [fuson_tokenizer.decode([idx]) for idx in token_indices]\n",
|
200 |
+
"filtered_tokens = [tokens[i] for i in filtered_indices]\n",
|
201 |
+
"all_logits_array = np.vstack(logits_for_each_AA)\n",
|
202 |
+
"normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()\n",
|
203 |
+
"transposed_logits_array = normalized_logits_array.T"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"cell_type": "code",
|
208 |
+
"execution_count": null,
|
209 |
+
"metadata": {},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"# add these logits as a b factor to the pdb file \n",
|
213 |
+
"import re\n",
|
214 |
+
"\n",
|
215 |
+
"path_to_cif = f\"../../caid/raw_data/fusionpdb/structures/{selected_row['Fusion_Structure_Link']}\"\n",
|
216 |
+
"path_to_modified_cif = f\"{selected_row['Fusion_Structure_Link'].split('.')[0]}_logit_bfactor.cif\"\n",
|
217 |
+
"\n",
|
218 |
+
"def modify(path_to_cif, path_to_modified_cif, new_b_values):\n",
|
219 |
+
" with open(path_to_cif, 'r') as f:\n",
|
220 |
+
" lines = f.readlines()\n",
|
221 |
+
"\n",
|
222 |
+
" exline = ''\n",
|
223 |
+
" with open(path_to_modified_cif, 'w') as f:\n",
|
224 |
+
" in_aa_loop=False\n",
|
225 |
+
" in_atom_loop=False\n",
|
226 |
+
" done_aa=False\n",
|
227 |
+
" done_atom=False\n",
|
228 |
+
" counter_aa=0\n",
|
229 |
+
" counter_atom=0\n",
|
230 |
+
" seqlen=len(selected_row['Fusion_Seq'])\n",
|
231 |
+
" for line in lines:\n",
|
232 |
+
" newline=line\n",
|
233 |
+
" if line.startswith(\"A\\t\") and not(done_aa):\n",
|
234 |
+
" in_aa_loop=True\n",
|
235 |
+
" else:\n",
|
236 |
+
" in_aa_loop=False\n",
|
237 |
+
" if line.startswith(\"ATOM \") and not(done_atom):\n",
|
238 |
+
" in_atom_loop=True\n",
|
239 |
+
" else:\n",
|
240 |
+
" in_atom_loop=False\n",
|
241 |
+
" \n",
|
242 |
+
" if in_aa_loop:\n",
|
243 |
+
" counter_aa+=1\n",
|
244 |
+
" split_line = line.split()\n",
|
245 |
+
" one_indexed_pos = int(split_line[2])\n",
|
246 |
+
" new_value = new_b_values[one_indexed_pos]\n",
|
247 |
+
" new_value = round(new_value,4)\n",
|
248 |
+
" split_line[4] = str(new_value)\n",
|
249 |
+
" newline = '\\t'.join(split_line)+'\\n'\n",
|
250 |
+
" \n",
|
251 |
+
" if in_atom_loop:\n",
|
252 |
+
" counter_atom+=1\n",
|
253 |
+
" split_line = re.split(r'(\\s+)', line)\n",
|
254 |
+
" one_indexed_pos = int(split_line[16])\n",
|
255 |
+
" new_value = new_b_values[one_indexed_pos]\n",
|
256 |
+
" new_value = round(new_value,4)\n",
|
257 |
+
" split_line[28] = str(new_value)\n",
|
258 |
+
" newline = ''.join(split_line)\n",
|
259 |
+
" #print(split_line)\n",
|
260 |
+
" \n",
|
261 |
+
" if counter_aa==seqlen:\n",
|
262 |
+
" done_aa=True\n",
|
263 |
+
" in_aa_loop=False\n",
|
264 |
+
" \n",
|
265 |
+
" f.write(newline)\n"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": null,
|
271 |
+
"metadata": {},
|
272 |
+
"outputs": [],
|
273 |
+
"source": [
|
274 |
+
"# read ETV6::NTRK3 predictions\n",
|
275 |
+
"path_to_preds = \"../discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv\"\n",
|
276 |
+
"preds = pd.read_csv(path_to_preds)\n",
|
277 |
+
"original_seq = ''.join(preds['original_residue'].tolist())\n",
|
278 |
+
"print(original_seq)\n",
|
279 |
+
"\n",
|
280 |
+
"# what is the structural data for this\n",
|
281 |
+
"structure_data = fusion_structure_data.loc[\n",
|
282 |
+
" fusion_structure_data['Fusion_Seq']==original_seq\n",
|
283 |
+
"].reset_index(drop=True).loc[0]\n",
|
284 |
+
"print(structure_data.to_string())\n",
|
285 |
+
"print(structure_data['top_hg_UniProt_fus_indices'], structure_data['top_tg_UniProt_fus_indices'])\n",
|
286 |
+
"print(structure_data['Fusion_Structure_Link'])\n"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "code",
|
291 |
+
"execution_count": null,
|
292 |
+
"metadata": {},
|
293 |
+
"outputs": [],
|
294 |
+
"source": [
|
295 |
+
"path_to_cif = f\"../../caid/raw_data/fusionpdb/structures/{structure_data['Fusion_Structure_Link']}\"\n",
|
296 |
+
"path_to_modified_cif = f\"processed_data/{structure_data['Fusion_Structure_Link'].split('.')[0]}_logit_bfactor.cif\"\n",
|
297 |
+
"\n",
|
298 |
+
"new_b_values = dict(zip(preds[\"Residue\"],preds[\"original_residue_logit\"]))\n",
|
299 |
+
"\n",
|
300 |
+
"modify(path_to_cif, path_to_modified_cif, new_b_values)"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"execution_count": null,
|
306 |
+
"metadata": {},
|
307 |
+
"outputs": [],
|
308 |
+
"source": [
|
309 |
+
"# spectrum b to color by logits\n",
|
310 |
+
"spectrum b\n",
|
311 |
+
"\n",
|
312 |
+
"sele head, resi 1-338\n",
|
313 |
+
"sele tail, resi 339-647\n",
|
314 |
+
"sele kinase, resi 346-647\n",
|
315 |
+
"set cartoon_transparency, 0.8, tail\n",
|
316 |
+
"set cartoon_transparency, 0.8, kinase\n",
|
317 |
+
"sele G623R, resi 431\n",
|
318 |
+
"sele G696A, resi 504\n",
|
319 |
+
"color orange, G623R\n",
|
320 |
+
"set cartoon_transparency, 0, G623R\n",
|
321 |
+
"show licorice, G623R\n",
|
322 |
+
"color orange, G696A\n",
|
323 |
+
"set cartoon_transparency, 0, G696A\n",
|
324 |
+
"show licorice, G696A"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "code",
|
329 |
+
"execution_count": null,
|
330 |
+
"metadata": {},
|
331 |
+
"outputs": [],
|
332 |
+
"source": [
|
333 |
+
"# PYMOL code for ETV6::NTRK3\n",
|
334 |
+
"\n",
|
335 |
+
"\n",
|
336 |
+
"# Define a custom color for the kinase domain\n",
|
337 |
+
"set_color custom_red, [0xdf/255, 0x83/255, 0x85/255]\n",
|
338 |
+
"set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
|
339 |
+
"\n",
|
340 |
+
"# Select and color residues 1085-1336\n",
|
341 |
+
"sele head, resi 1-338\n",
|
342 |
+
"sele tail, resi 339-647\n",
|
343 |
+
"sele kinase, resi 346-647\n",
|
344 |
+
"color custom_red, head\n",
|
345 |
+
"color custom_blue, tail\n",
|
346 |
+
"set cartoon_transparency, 0, tail\n",
|
347 |
+
"set cartoon_transparency, 0.8, kinase\n",
|
348 |
+
"sele G623R, resi 431\n",
|
349 |
+
"sele G696A, resi 504\n",
|
350 |
+
"color magenta, G623R\n",
|
351 |
+
"set cartoon_transparency, 0, G623R\n",
|
352 |
+
"show licorice, G623R\n",
|
353 |
+
"color magenta, G696A\n",
|
354 |
+
"set cartoon_transparency, 0, G696A\n",
|
355 |
+
"show licorice, G696A\n",
|
356 |
+
"\n",
|
357 |
+
"# Color the known mutations that impact drug efficacy\n",
|
358 |
+
"\n",
|
359 |
+
"# Select missed mutation residues, color them viridis, and make them fully opaque\n",
|
360 |
+
"# do the viridis outside the "
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "code",
|
365 |
+
"execution_count": null,
|
366 |
+
"metadata": {},
|
367 |
+
"outputs": [],
|
368 |
+
"source": [
|
369 |
+
"# PYMOL code for ETV6::NTRK3\n",
|
370 |
+
"# Set global cartoon transparency\n",
|
371 |
+
"# 0 for zoomed out, 0.5 for close up\n",
|
372 |
+
"color gray60\n",
|
373 |
+
"set cartoon_transparency, 0\n",
|
374 |
+
"\n",
|
375 |
+
"# Define a custom color for the kinase domain\n",
|
376 |
+
"set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
|
377 |
+
"\n",
|
378 |
+
"# Select and color residues 1085-1336\n",
|
379 |
+
"sele ntrk3_kinase, resi 225-526\n",
|
380 |
+
"color custom_blue, ntrk3_kinase\n",
|
381 |
+
"\n",
|
382 |
+
"# Select missed mutation residues, color them orange, and make them fully opaque\n",
|
383 |
+
"spectrum b, blue_white_red, minimum=0.8, maximum=1"
|
384 |
+
]
|
385 |
+
}
|
386 |
+
],
|
387 |
+
"metadata": {
|
388 |
+
"colab": {
|
389 |
+
"collapsed_sections": [
|
390 |
+
"FJd6a9gdZNjG",
|
391 |
+
"zORkLJztZWp9",
|
392 |
+
"w25hagtZaV65",
|
393 |
+
"IbyqxlvAFUAK",
|
394 |
+
"0n5PSprbhLk7"
|
395 |
+
],
|
396 |
+
"machine_shape": "hm",
|
397 |
+
"provenance": []
|
398 |
+
},
|
399 |
+
"kernelspec": {
|
400 |
+
"display_name": "Python 3",
|
401 |
+
"name": "python3"
|
402 |
+
},
|
403 |
+
"language_info": {
|
404 |
+
"codemirror_mode": {
|
405 |
+
"name": "ipython",
|
406 |
+
"version": 3
|
407 |
+
},
|
408 |
+
"file_extension": ".py",
|
409 |
+
"mimetype": "text/x-python",
|
410 |
+
"name": "python",
|
411 |
+
"nbconvert_exporter": "python",
|
412 |
+
"pygments_lexer": "ipython3",
|
413 |
+
"version": "3.10.12"
|
414 |
+
}
|
415 |
+
},
|
416 |
+
"nbformat": 4,
|
417 |
+
"nbformat_minor": 0
|
418 |
+
}
|
fuson_plm/benchmarking/mutation_prediction/discovery/config.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model settings: where is the model you wish to use for mutation discovery?
|
2 |
+
FUSON_PLM_CKPT = "FusOn-pLM"
|
3 |
+
|
4 |
+
#### Fill in this sectinon if you have one input
|
5 |
+
# Sequence settings: need full sequence of fusion oncoprotein, and the bounds of region of interest
|
6 |
+
FULL_FUSION_SEQUENCE = ""
|
7 |
+
FUSION_NAME = "fusion_example"
|
8 |
+
START_RESIDUE_INDEX = 1
|
9 |
+
END_RESIDUE_INDEX = 100
|
10 |
+
N = 3 # number of mutations to predict per amio acid
|
11 |
+
|
12 |
+
#### Fill in this section if you have multiple input
|
13 |
+
PATH_TO_INPUT_FILE = "processed_data/domain_conservation_fusions_inputfile.csv" # if you don't have an input file and want to do one sequence, set this variable to None
|
14 |
+
|
15 |
+
# GPU Settings: which GPUs should be available to run this discovery?
|
16 |
+
CUDA_VISIBLE_DEVICES = "0"
|
fuson_plm/benchmarking/mutation_prediction/discovery/discover.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##### Discover mutations in new sequences. A tool
|
2 |
+
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
import transformers
|
10 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
11 |
+
import logging
|
12 |
+
import torch
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import seaborn as sns
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
|
18 |
+
from fuson_plm.utils.embedding import load_esm2_type
|
19 |
+
from fuson_plm.utils.visualizing import set_font
|
20 |
+
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables, predict_positionwise_mutations
|
21 |
+
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap
|
22 |
+
|
23 |
+
def check_seq_inputs(sequence, AAs_tokens):
|
24 |
+
# checking sequence inputs for validity
|
25 |
+
if not sequence.strip():
|
26 |
+
raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
|
27 |
+
return None, None, None
|
28 |
+
if any(char not in AAs_tokens for char in sequence):
|
29 |
+
raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
|
30 |
+
return None, None, None
|
31 |
+
|
32 |
+
def check_domain_bounds(domain_bounds):
|
33 |
+
try:
|
34 |
+
start = int(domain_bounds['start'])
|
35 |
+
end = int(domain_bounds['end'])
|
36 |
+
return start, end
|
37 |
+
except ValueError:
|
38 |
+
raise Exception("Error: Start and end indices must be integers.")
|
39 |
+
return None, None
|
40 |
+
if start >= end:
|
41 |
+
raise Exception("Start index must be smaller than end index.")
|
42 |
+
return None, None
|
43 |
+
if start == 0 and end != 0:
|
44 |
+
raise Exception("Indexing starts at 1. Please enter valid domain bounds.")
|
45 |
+
return None, None
|
46 |
+
if start <= 0 or end <= 0:
|
47 |
+
raise Exception("Domain bounds must be positive integers. Please enter valid domain bounds.")
|
48 |
+
return None, None
|
49 |
+
if start > len(sequence) or end > len(sequence):
|
50 |
+
raise Exception("Domain bounds exceed sequence length.")
|
51 |
+
return None, None
|
52 |
+
|
53 |
+
def check_n_input(n):
|
54 |
+
if n < 1:
|
55 |
+
raise Exception("Choose N>=1")
|
56 |
+
return None, None, None
|
57 |
+
|
58 |
+
def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device):
|
59 |
+
# Define amino acids and their token indices
|
60 |
+
AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
|
61 |
+
AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
|
62 |
+
'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}
|
63 |
+
|
64 |
+
# checking all inputs for validity
|
65 |
+
log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
|
66 |
+
check_seq_inputs(sequence, AAs_tokens)
|
67 |
+
start, end = check_domain_bounds(domain_bounds)
|
68 |
+
check_n_input(n)
|
69 |
+
|
70 |
+
# define start_index as start - 1 (because residues are 1-indexed, while Python is 0-indexed). end is same
|
71 |
+
start_index = start - 1
|
72 |
+
end_index = end
|
73 |
+
|
74 |
+
# place to store top n mutations and all logits
|
75 |
+
top_n_mutations = {}
|
76 |
+
top_n_advantage_mutations = {}
|
77 |
+
top_n_disadvantage_mutations = {}
|
78 |
+
logits_for_each_AA = []
|
79 |
+
llrs_for_each_AA = []
|
80 |
+
|
81 |
+
# storage for the conservation heatmap
|
82 |
+
originals_logits = []
|
83 |
+
conservation_likelihoods = {}
|
84 |
+
|
85 |
+
log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
|
86 |
+
for i in range(len(sequence)):
|
87 |
+
# only iterate through the residues inside the domain
|
88 |
+
if start_index <= i <= (end_index - 1):
|
89 |
+
# isolate original residue and its index
|
90 |
+
original_residue = sequence[i]
|
91 |
+
original_residue_index = AAs_tokens_indices[original_residue]
|
92 |
+
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
93 |
+
|
94 |
+
# prepare log
|
95 |
+
masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
|
96 |
+
log_starti = i-min(5, i)
|
97 |
+
log_endi = i+5
|
98 |
+
log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")
|
99 |
+
# prepare inputs
|
100 |
+
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
|
101 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
102 |
+
# forward pass
|
103 |
+
with torch.no_grad():
|
104 |
+
logits = model(**inputs).logits
|
105 |
+
|
106 |
+
# Find masked positions and extract their logits
|
107 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
108 |
+
mask_token_logits = logits[0, mask_token_index, :] # shape: [1, vocab_size] == [1, 33]. logits for each vocab word at this position in the sequence
|
109 |
+
|
110 |
+
# Collect logits for the full heamtap
|
111 |
+
logits_array = mask_token_logits.cpu().numpy() # shape: [1, 33]
|
112 |
+
# filter out non-amino acid tokens
|
113 |
+
filtered_indices = list(range(4, 23 + 1)) # filtered indices are indices of amino acids
|
114 |
+
filtered_logits = logits_array[:, filtered_indices] # shape: [1, 20] only the 20 amino acids
|
115 |
+
logits_for_each_AA.append(filtered_logits) # get logits for each amino acid
|
116 |
+
|
117 |
+
# Collect LLRs for the LLR heatmap
|
118 |
+
log_probabilities = F.log_softmax(torch.tensor(mask_token_logits).cpu(), dim=-1).squeeze(0) # take log softmax of the [33] dimension
|
119 |
+
log_prob_og = log_probabilities[original_residue_index] # get the log probability of the TRUE residue underneath the mask
|
120 |
+
llrs = torch.tensor([(x-log_prob_og) for x in log_probabilities]) # calculate the LLR
|
121 |
+
#print(original_residue_index, llrs)
|
122 |
+
filtered_llrs = llrs[filtered_indices].numpy() # filter so it's [20], just the amino acids; only save this
|
123 |
+
filtered_llrs_array = np.array([filtered_llrs])
|
124 |
+
llrs_for_each_AA.append(filtered_llrs_array)
|
125 |
+
|
126 |
+
######### Top tokens
|
127 |
+
# Get top tokens based on LOGITS
|
128 |
+
all_tokens_logits = mask_token_logits.squeeze(0) # shape: [vocab_size] == [33]
|
129 |
+
top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True) # sort the logits
|
130 |
+
mutation = []
|
131 |
+
# make sure we don't include non-AA tokens
|
132 |
+
for token_index in top_tokens_indices:
|
133 |
+
decoded_token = tokenizer.decode([token_index.item()])
|
134 |
+
# decoded all tokens, pick the top n amino acid ones
|
135 |
+
if decoded_token in AAs_tokens:
|
136 |
+
mutation.append(decoded_token)
|
137 |
+
if len(mutation) == n:
|
138 |
+
break
|
139 |
+
top_n_mutations[(sequence[i], i)] = mutation
|
140 |
+
log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")
|
141 |
+
|
142 |
+
# Get top tokens based on LLR
|
143 |
+
top_advantage_tokens_indices = torch.argsort(llrs, dim=0, descending=True) # sort the LLRs
|
144 |
+
advantage_mutation = []
|
145 |
+
# make sure we don't include non-AA tokens
|
146 |
+
for token_index in top_advantage_tokens_indices:
|
147 |
+
decoded_token = tokenizer.decode([token_index.item()])
|
148 |
+
# decoded all tokens, pick the top n amino acid ones
|
149 |
+
if decoded_token in AAs_tokens:
|
150 |
+
advantage_mutation.append(decoded_token)
|
151 |
+
if len(advantage_mutation) == n:
|
152 |
+
break
|
153 |
+
top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
|
154 |
+
log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")
|
155 |
+
|
156 |
+
# Get top tokens based on LLR
|
157 |
+
top_disadvantage_tokens_indices = torch.argsort(llrs, dim=0, descending=False) # sort the LLRs
|
158 |
+
disadvantage_mutation = []
|
159 |
+
# make sure we don't include non-AA tokens
|
160 |
+
for token_index in top_disadvantage_tokens_indices:
|
161 |
+
decoded_token = tokenizer.decode([token_index.item()])
|
162 |
+
# decoded all tokens, pick the top n amino acid ones
|
163 |
+
if decoded_token in AAs_tokens:
|
164 |
+
disadvantage_mutation.append(decoded_token)
|
165 |
+
if len(disadvantage_mutation) == n:
|
166 |
+
break
|
167 |
+
top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
|
168 |
+
log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")
|
169 |
+
|
170 |
+
# fill in the logits and conservation likelihoods for the second array
|
171 |
+
normalized_mask_token_logits = F.softmax(torch.tensor(mask_token_logits).cpu(), dim=-1).numpy()
|
172 |
+
normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
|
173 |
+
originals_logit = normalized_mask_token_logits[original_residue_index]
|
174 |
+
originals_logits.append(originals_logit)
|
175 |
+
|
176 |
+
# a region is conserved if the probability of the amino acid is > 0.7; AKA, probability of a mutation is <= 0.3
|
177 |
+
if originals_logit > 0.7:
|
178 |
+
conservation_likelihoods[(original_residue, i)] = 1
|
179 |
+
log_update("\t\tConserved position")
|
180 |
+
else:
|
181 |
+
conservation_likelihoods[(original_residue, i)] = 0
|
182 |
+
log_update("\t\tNot conserved position")
|
183 |
+
|
184 |
+
# return a dictionary of all the things we need for the next part
|
185 |
+
return {
|
186 |
+
'start': start,
|
187 |
+
'end': end,
|
188 |
+
'originals_logits': originals_logits,
|
189 |
+
'conservation_likelihoods': conservation_likelihoods,
|
190 |
+
'logits': logits,
|
191 |
+
'filtered_indices': filtered_indices,
|
192 |
+
'top_n_mutations': top_n_mutations,
|
193 |
+
'top_n_advantage_mutations': top_n_advantage_mutations,
|
194 |
+
'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
|
195 |
+
'logits_for_each_AA': logits_for_each_AA,
|
196 |
+
'llrs_for_each_AA': llrs_for_each_AA
|
197 |
+
}
|
198 |
+
|
199 |
+
def find_top_3(d):
|
200 |
+
temp = pd.DataFrame.from_dict(d, orient='index').reset_index()
|
201 |
+
temp = temp.sort_values(by=0,ascending=False).reset_index(drop=True)
|
202 |
+
temp = temp.iloc[0:3,:]
|
203 |
+
return_d = dict(zip(temp['index'],temp[0]))
|
204 |
+
return return_d
|
205 |
+
|
206 |
+
def make_full_results_df(mutation_results, tokenizer, original_sequence):
|
207 |
+
# Unpack mutation results
|
208 |
+
logits = mutation_results['logits']
|
209 |
+
logits_for_each_AA = mutation_results['logits_for_each_AA']
|
210 |
+
filtered_indices = mutation_results['filtered_indices']
|
211 |
+
llrs_for_each_AA = mutation_results['llrs_for_each_AA']
|
212 |
+
|
213 |
+
token_indices = torch.arange(logits.size(-1))
|
214 |
+
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
215 |
+
filtered_tokens = [tokens[i] for i in filtered_indices]
|
216 |
+
all_logits_array = np.vstack(logits_for_each_AA)
|
217 |
+
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
218 |
+
transposed_logits_array = normalized_logits_array.T
|
219 |
+
|
220 |
+
df = pd.DataFrame(transposed_logits_array.T)
|
221 |
+
df.columns = filtered_tokens
|
222 |
+
df.index = list(range(1, len(df)+1))
|
223 |
+
df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
|
224 |
+
df['top_3_mutations'] = df['all_logits'].apply(lambda x: find_top_3(x))
|
225 |
+
df['original_residue'] = list(original_sequence)
|
226 |
+
df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']],axis=1)
|
227 |
+
df = df[['original_residue','original_residue_logit','all_logits','top_3_mutations']]
|
228 |
+
df = df.reset_index().rename(columns={'index':'Residue'})
|
229 |
+
return df
|
230 |
+
|
231 |
+
def make_small_results_df(mutation_results):
|
232 |
+
conservation_likelihoods = mutation_results['conservation_likelihoods']
|
233 |
+
top_n_mutations = mutation_results['top_n_mutations']
|
234 |
+
|
235 |
+
# store the predicted mutations in a dataframe
|
236 |
+
original_residues = []
|
237 |
+
mutations = []
|
238 |
+
positions = []
|
239 |
+
conserved = []
|
240 |
+
|
241 |
+
for key, value in top_n_mutations.items():
|
242 |
+
original_residue, position = key
|
243 |
+
original_residues.append(original_residue)
|
244 |
+
mutations.append(','.join(value))
|
245 |
+
positions.append(position + 1)
|
246 |
+
|
247 |
+
for i, (key, value) in enumerate(conservation_likelihoods.items()):
|
248 |
+
original_residue, position = key
|
249 |
+
if original_residues[i]==original_residue: # it should, otherwise something is wrong
|
250 |
+
conserved.append(value)
|
251 |
+
|
252 |
+
df = pd.DataFrame({
|
253 |
+
'Original Residue': original_residues,
|
254 |
+
'Predicted Residues': mutations,
|
255 |
+
'Conserved': conserved,
|
256 |
+
'Position': positions
|
257 |
+
})
|
258 |
+
return df
|
259 |
+
|
260 |
+
|
261 |
+
def main():
|
262 |
+
# Make results directory
|
263 |
+
os.makedirs('results',exist_ok=True)
|
264 |
+
output_dir = f'results/{get_local_time()}'
|
265 |
+
os.makedirs(output_dir,exist_ok=True)
|
266 |
+
|
267 |
+
# Predict mutations, writing results to a log inside of the output directory
|
268 |
+
with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
|
269 |
+
print_configpy(config)
|
270 |
+
# Make sure environment variables are set correctly
|
271 |
+
check_env_variables()
|
272 |
+
|
273 |
+
# Get device
|
274 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
275 |
+
log_update(f"Using device: {device}")
|
276 |
+
|
277 |
+
# Load fuson as AutoModelForMaskedLM
|
278 |
+
fuson_ckpt_path = config.FUSON_PLM_CKPT
|
279 |
+
if fuson_ckpt_path=="FusOn-pLM":
|
280 |
+
fuson_ckpt_path="../../../.."
|
281 |
+
model_name = "fuson_plm"
|
282 |
+
model_epoch = "best"
|
283 |
+
model_str = f"fuson_plm/best"
|
284 |
+
else:
|
285 |
+
model_name = list(fuson_ckpt_path.keys())[0]
|
286 |
+
epoch = list(fuson_ckpt_path.values())[0]
|
287 |
+
fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
|
288 |
+
model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
|
289 |
+
model_epoch = model_epoch.split('checkpoint_')[-1]
|
290 |
+
model_str = f"{model_name}/{model_epoch}"
|
291 |
+
|
292 |
+
log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
|
293 |
+
fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
|
294 |
+
fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
|
295 |
+
fuson_model.to(device)
|
296 |
+
fuson_model.eval()
|
297 |
+
|
298 |
+
if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
|
299 |
+
input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
|
300 |
+
else:
|
301 |
+
input_file = pd.DataFrame(
|
302 |
+
data={
|
303 |
+
'fusion_name': [config.FUSION_NAME],
|
304 |
+
'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
|
305 |
+
'start_residue_index': [config.START_RESIDUE_INDEX],
|
306 |
+
'end_residue_index': [config.END_RESIDUE_INDEX],
|
307 |
+
'n': [config.N]
|
308 |
+
}
|
309 |
+
)
|
310 |
+
|
311 |
+
log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
|
312 |
+
log_update("\t" + "\n\t".join(input_file['fusion_name']))
|
313 |
+
# Loop through each input and make a subfolder with its data
|
314 |
+
for i in range(len(input_file)):
|
315 |
+
row = input_file.loc[i,:]
|
316 |
+
fusion_name = row['fusion_name']
|
317 |
+
full_fusion_sequence = row['full_fusion_sequence']
|
318 |
+
start_residue_index = row['start_residue_index']
|
319 |
+
end_residue_index = row['end_residue_index']
|
320 |
+
n = row['n']
|
321 |
+
sub_output_dir = f"{output_dir}/{fusion_name}"
|
322 |
+
os.makedirs(sub_output_dir,exist_ok=True)
|
323 |
+
|
324 |
+
# Predict postionwise mutations, plot the results
|
325 |
+
domain_bounds = {'start': start_residue_index, 'end': end_residue_index}
|
326 |
+
mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, n,
|
327 |
+
fuson_model, fuson_tokenizer, device)
|
328 |
+
|
329 |
+
# Save mutation results
|
330 |
+
with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
|
331 |
+
pickle.dump(mutation_results, f)
|
332 |
+
|
333 |
+
# Plot the heatmaps
|
334 |
+
plot_full_heatmap(mutation_results, fuson_tokenizer,
|
335 |
+
fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
|
336 |
+
plot_conservation_heatmap(mutation_results,
|
337 |
+
fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")
|
338 |
+
|
339 |
+
# Make results dataframe
|
340 |
+
small_mutation_results_df = make_small_results_df(mutation_results)
|
341 |
+
small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv",index=False)
|
342 |
+
full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
|
343 |
+
full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv",index=False)
|
344 |
+
|
345 |
+
if __name__ == "__main__":
|
346 |
+
main()
|
fuson_plm/benchmarking/mutation_prediction/discovery/make_color_bar.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def plot_color_bar():
|
5 |
+
"""
|
6 |
+
Create a Viridis color bar ranging from 0 to 1.
|
7 |
+
"""
|
8 |
+
# Create a gradient from 0 to 1
|
9 |
+
gradient = np.linspace(0, 1, 256).reshape(1, -1)
|
10 |
+
|
11 |
+
# Plot the gradient as a color bar
|
12 |
+
fig, ax = plt.subplots(figsize=(12, 3))
|
13 |
+
ax.imshow(gradient, aspect="auto", cmap="viridis")
|
14 |
+
ax.set_xticks([0, 255])
|
15 |
+
ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
|
16 |
+
ax.set_yticks([])
|
17 |
+
ax.set_title("Original Residue Logits", fontsize=40)
|
18 |
+
|
19 |
+
# Save the figure
|
20 |
+
plt.tight_layout()
|
21 |
+
plt.show()
|
22 |
+
plt.savefig("viridis_color_bar.png", dpi=300)
|
23 |
+
|
24 |
+
# Call the function to create and display the color bar
|
25 |
+
plot_color_bar()
|
fuson_plm/benchmarking/mutation_prediction/discovery/plot.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import seaborn as sns
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
import pickle
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
from fuson_plm.utils.visualizing import set_font
|
11 |
+
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
|
12 |
+
|
13 |
+
def get_x_tick_labels(start, end):
|
14 |
+
# Define start and end index which we actually use to index the sequence
|
15 |
+
start_index = start - 1
|
16 |
+
end_index = end
|
17 |
+
|
18 |
+
# Define domain length
|
19 |
+
domain_len = end - start
|
20 |
+
if 500 > domain_len > 100:
|
21 |
+
step_size = 50
|
22 |
+
elif 500 <= domain_len:
|
23 |
+
step_size = 100
|
24 |
+
elif domain_len < 10:
|
25 |
+
step_size = 1
|
26 |
+
else:
|
27 |
+
step_size = 10
|
28 |
+
|
29 |
+
# Define x tick positions based on step size
|
30 |
+
x_tick_positions = np.arange(start_index, end_index, step_size)
|
31 |
+
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
32 |
+
|
33 |
+
return x_tick_positions, x_tick_labels
|
34 |
+
|
35 |
+
|
36 |
+
def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
|
37 |
+
start = mutation_results['start']
|
38 |
+
end = mutation_results['end']
|
39 |
+
originals_logits = mutation_results['originals_logits']
|
40 |
+
conservation_likelihoods = mutation_results['conservation_likelihoods']
|
41 |
+
logits = mutation_results['logits']
|
42 |
+
logits_for_each_AA = mutation_results['logits_for_each_AA']
|
43 |
+
filtered_indices = mutation_results['filtered_indices']
|
44 |
+
top_n_mutations = mutation_results['top_n_mutations']
|
45 |
+
|
46 |
+
# Get start index and end index
|
47 |
+
start_index = start - 1
|
48 |
+
end_index = end
|
49 |
+
|
50 |
+
# Make conservation likelihoods array for plotting
|
51 |
+
all_logits_array = np.vstack(originals_logits)
|
52 |
+
transposed_logits_array = all_logits_array.T
|
53 |
+
conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
|
54 |
+
# combine to make a 2D heatmap
|
55 |
+
combined_array = np.vstack((transposed_logits_array, conservation_likelihoods_array))
|
56 |
+
|
57 |
+
# Get ticks
|
58 |
+
x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)
|
59 |
+
|
60 |
+
# Plot!
|
61 |
+
set_font()
|
62 |
+
# Adjust the figure size: constant height (e.g., 3) and width proportional to sequence length
|
63 |
+
sequence_length = end_index - start_index
|
64 |
+
fig = plt.figure(figsize=(min(15, sequence_length / 10), 3)) # Adjust width dynamically, keep height constant
|
65 |
+
|
66 |
+
#plt.rcParams.update({'font.size': 16.5}) # make font size bigger
|
67 |
+
ax = sns.heatmap(
|
68 |
+
combined_array,
|
69 |
+
cmap='viridis',
|
70 |
+
xticklabels=x_tick_labels,
|
71 |
+
yticklabels=['Original Logits', 'Conserved'],
|
72 |
+
cbar=True,
|
73 |
+
cbar_kws={'aspect': 2,
|
74 |
+
'pad': 0.02,
|
75 |
+
'shrink': 1.0, # Adjust the overall size of the color bar
|
76 |
+
}
|
77 |
+
)
|
78 |
+
# Access the color bar
|
79 |
+
cbar = ax.collections[0].colorbar
|
80 |
+
|
81 |
+
# Change the font size of the tick labels on the color bar
|
82 |
+
cbar.ax.tick_params(labelsize=20) # Adjust the font size of tick labels
|
83 |
+
|
84 |
+
plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
|
85 |
+
plt.yticks(fontsize=20, rotation=0)
|
86 |
+
plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
|
87 |
+
plt.xlabel('Residue Index', fontsize=30)
|
88 |
+
plt.tight_layout()
|
89 |
+
plt.show()
|
90 |
+
|
91 |
+
# save the figure
|
92 |
+
plt.savefig(save_path, format='png', dpi=300)
|
93 |
+
|
94 |
+
# plotting heatmap 1
|
95 |
+
def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
|
96 |
+
start = mutation_results['start']
|
97 |
+
end = mutation_results['end']
|
98 |
+
logits = mutation_results['logits']
|
99 |
+
logits_for_each_AA = mutation_results['logits_for_each_AA']
|
100 |
+
filtered_indices = mutation_results['filtered_indices']
|
101 |
+
|
102 |
+
# get start and end index
|
103 |
+
start_index = start - 1
|
104 |
+
end_index = end
|
105 |
+
|
106 |
+
# prepare data for plotting
|
107 |
+
token_indices = torch.arange(logits.size(-1))
|
108 |
+
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
109 |
+
filtered_tokens = [tokens[i] for i in filtered_indices]
|
110 |
+
all_logits_array = np.vstack(logits_for_each_AA)
|
111 |
+
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
112 |
+
transposed_logits_array = normalized_logits_array.T
|
113 |
+
|
114 |
+
# get x tick labels
|
115 |
+
x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)
|
116 |
+
|
117 |
+
# make plot
|
118 |
+
set_font()
|
119 |
+
fig = plt.figure(figsize=(15, 8))
|
120 |
+
plt.rcParams.update({'font.size': 16.5})
|
121 |
+
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
122 |
+
plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
|
123 |
+
plt.ylabel('Amino Acid')
|
124 |
+
plt.xlabel('Residue Index')
|
125 |
+
plt.yticks(rotation=0)
|
126 |
+
plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
|
127 |
+
plt.tight_layout()
|
128 |
+
plt.savefig(save_path, format='png', dpi = 300)
|
129 |
+
|
130 |
+
def plot_color_bar():
|
131 |
+
"""
|
132 |
+
Create a Viridis color bar ranging from 0 to 1.
|
133 |
+
"""
|
134 |
+
# Create a gradient from 0 to 1
|
135 |
+
gradient = np.linspace(0, 1, 256).reshape(1, -1)
|
136 |
+
|
137 |
+
# Plot the gradient as a color bar
|
138 |
+
fig, ax = plt.subplots(figsize=(12, 3))
|
139 |
+
ax.imshow(gradient, aspect="auto", cmap="viridis")
|
140 |
+
ax.set_xticks([0, 255])
|
141 |
+
ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
|
142 |
+
ax.set_yticks([])
|
143 |
+
ax.set_title("Original Residue Logits", fontsize=40)
|
144 |
+
|
145 |
+
# Save the figure
|
146 |
+
plt.tight_layout()
|
147 |
+
plt.show()
|
148 |
+
plt.savefig("viridis_color_bar.png", dpi=300)
|
149 |
+
|
150 |
+
def main():
|
151 |
+
# Call the function to create and display the color bar
|
152 |
+
plot_color_bar()
|
153 |
+
|
154 |
+
results_dir = "results/final"
|
155 |
+
subfolders = os.listdir(results_dir)
|
156 |
+
for subfolder in subfolders:
|
157 |
+
full_path = f"{results_dir}/{subfolder}"
|
158 |
+
if os.path.isdir(full_path):
|
159 |
+
with open(f"{full_path}/raw_mutation_results.pkl", "rb") as f:
|
160 |
+
mutation_results = pickle.load(f)
|
161 |
+
plot_conservation_heatmap(mutation_results,
|
162 |
+
fusion_name=subfolder, save_path=f"{full_path}/conservation_heatmap.png")
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == "__main__":
|
167 |
+
main()
|
fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/521_logit_bfactor.cif
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pytorch_model.bin → fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/domain_conservation_fusions_inputfile.csv
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bef22b6753a20845f095872431e1a222ac2be40ffa20bfe36113d80a7b57b2
|
3 |
+
size 3119
|
fuson_plm/benchmarking/mutation_prediction/discovery/processed_data/test_seqs_tftf_kk.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e77c7777c0163df7b4ea740d16c5aff0a02cec03a251130dde0cf9b291d4fa2
|
3 |
+
size 4023866
|
fuson_plm/benchmarking/mutation_prediction/discovery/raw_data/salokas_2020_tableS3.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
|
3 |
+
size 19626
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/conservation_heatmap.png
ADDED
![]() |
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/full_results_with_logits.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e881d4a4c081da84a60991e2ad9598fcf1dba47cbcb92b5cb6d70c29f4217ea
|
3 |
+
size 432231
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/predicted_tokens.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3349b52a936dbce7810faf3e17a296cf63701fb87e64d51de2d880575778b1b0
|
3 |
+
size 10299
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/ETV6::NTRK3/raw_mutation_results.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9762879aacee6bd42440326c3d7b68871df8caacf97127db213c47eec9fa7d1f
|
3 |
+
size 315798
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/conservation_heatmap.png
ADDED
![]() |
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/full_results_with_logits.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:82cddf491a55bf2254ab86b2a20baf161097b5538a5868c0660c88baf856f672
|
3 |
+
size 349203
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/predicted_tokens.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ebd172fd6717511efd5b4971e0370381f1a29d29b12dbc196b82585ed2ef5065
|
3 |
+
size 8363
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/EWSR1::FLI1/raw_mutation_results.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4973a6fc35ee9647f02f83627181b83e5f861c5918dc40abf4abb872d3e7088b
|
3 |
+
size 256739
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/conservation_heatmap.png
ADDED
![]() |
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/full_results_with_logits.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d7841257792c6d83d2cfcd5456bf8761b3d3fc0dd69cd912ecfc05a80ad3dec
|
3 |
+
size 545639
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/predicted_tokens.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8d94cfad62e2b62bf60d62cfef02fa2328f5d37c7f1c141575d098896bc2103
|
3 |
+
size 13323
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/PAX3::FOXO1/raw_mutation_results.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0dab0b3ee7ea7f9c3c67fdef16bf81dd48654bb8bfdf12eedd58e1972d908a24
|
3 |
+
size 408037
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/conservation_heatmap.png
ADDED
![]() |
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/full_results_with_logits.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9ad3fc01c9ce3c67330d8fc79d14a551aa4d7f8c8d596a49e8e8042d6a8fbb1b
|
3 |
+
size 635205
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/predicted_tokens.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22cdcd8cfbd5f70f1067d2104e6f0e3fc5e4e5e3fddd736934ce905b35cd0cea
|
3 |
+
size 15195
|
fuson_plm/benchmarking/mutation_prediction/discovery/results/final/TRIM24::RET/raw_mutation_results.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d719d23b4c7e6deb79ac9ef3c5c84a1c0c3a2e4ae9926f9a460731bd78f9e0c
|
3 |
+
size 465133
|
fuson_plm/benchmarking/mutation_prediction/discovery/viridis_color_bar.png
ADDED
![]() |
fuson_plm/benchmarking/mutation_prediction/recovery/abl_mutations.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22679a1f55d49ed4a8d5309c27215a4a204f47b49ce0a7a266a674208e381e85
|
3 |
+
size 307
|
fuson_plm/benchmarking/mutation_prediction/recovery/alk_mutations.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9339f8cd5ea5369a9b87ca3b420ee2731c4e057d5c546c3c2a09c727154721df
|
3 |
+
size 192
|
fuson_plm/benchmarking/mutation_prediction/recovery/color_recovered_mutations_public.ipynb
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "FJd6a9gdZNjG"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"### Imports"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"## Put path to model predictions you'd like to use for benchmarking\n",
|
19 |
+
"path_to_bcr_abl_preds = \"results/11-21-2024-16:03:59/Supplementary Tables - BCR ABL Mutations.csv\"\n",
|
20 |
+
"path_to_eml4_alk_preds = \"results/11-21-2024-16:03:59/Supplementary Tables - EML4 ALK Mutations.csv\""
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": null,
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"!pip install torch pandas numpy py3Dmol scikit-learn"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 3,
|
35 |
+
"metadata": {
|
36 |
+
"id": "ZEWZVc9lUxjI"
|
37 |
+
},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"import torch\n",
|
41 |
+
"import torch.nn as nn\n",
|
42 |
+
"\n",
|
43 |
+
"import pickle\n",
|
44 |
+
"import pandas as pd\n",
|
45 |
+
"import numpy as np\n",
|
46 |
+
"\n",
|
47 |
+
"import py3Dmol\n",
|
48 |
+
"\n",
|
49 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": null,
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"bcr_abl_preds = pd.read_csv(path_to_bcr_abl_preds)\n",
|
59 |
+
"eml4_alk_preds = pd.read_csv(path_to_eml4_alk_preds)\n",
|
60 |
+
"bcr_abl_seq = ''.join(bcr_abl_preds['Original Residue'].tolist())\n",
|
61 |
+
"eml4_alk_seq = ''.join(eml4_alk_preds['Original Residue'].tolist())\n",
|
62 |
+
"\n",
|
63 |
+
"print(\"BCR::ABL seq: \", bcr_abl_seq)\n",
|
64 |
+
"print(\"EML4::ALK seq: \", eml4_alk_seq)"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"bcr_abl_preds = bcr_abl_preds[bcr_abl_preds['Literature Mutation'].notna()].reset_index(drop=True)\n",
|
74 |
+
"eml4_alk_preds = eml4_alk_preds[eml4_alk_preds['Literature Mutation'].notna()].reset_index(drop=True)\n",
|
75 |
+
"\n",
|
76 |
+
"bcr_abl_preds"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "code",
|
81 |
+
"execution_count": null,
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"# Calculate hits\n",
|
86 |
+
"print(\"BCR::ABL\", len(bcr_abl_preds), sum(bcr_abl_preds['Hit']))\n",
|
87 |
+
"print(\"EML4::ALK\", len(eml4_alk_preds), sum(eml4_alk_preds['Hit']))"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"import os\n",
|
97 |
+
"os.getcwd()"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": null,
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"# Color EML4 ALK structure\n",
|
107 |
+
"path_to_eml4_alk_structure = np.nan # not publcly available\n",
|
108 |
+
"eml4_alk_seq = np.nan # not publicly available\n",
|
109 |
+
"alk_kinase_domain_seq = np.nan # not publicly available \n",
|
110 |
+
"alk_kinase_domain_resis = [eml4_alk_seq.index(alk_kinase_domain_seq)+1, eml4_alk_seq.index(alk_kinase_domain_seq)+len(alk_kinase_domain_seq)]\n",
|
111 |
+
"print(alk_kinase_domain_resis)\n",
|
112 |
+
"print(len(eml4_alk_seq))"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"l = list(range(567,843+1))\n",
|
122 |
+
"l = [str(x) for x in l]\n",
|
123 |
+
"print('+'.join(l))"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": null,
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [],
|
131 |
+
"source": [
|
132 |
+
"eml4_alk_mutation_resis = eml4_alk_preds['Position'].tolist()\n",
|
133 |
+
"eml4_alk_hit_resis = eml4_alk_preds.loc[\n",
|
134 |
+
" eml4_alk_preds['Hit']\n",
|
135 |
+
"]['Position'].tolist()\n",
|
136 |
+
"kinase_domain_coloring = {\n",
|
137 |
+
" 'kinase': [alk_kinase_domain_resis[0], alk_kinase_domain_resis[1], '#6ea4da']\n",
|
138 |
+
"}\n",
|
139 |
+
"\n",
|
140 |
+
"missed_mut_resis = [x for x in eml4_alk_mutation_resis if x not in eml4_alk_hit_resis]\n",
|
141 |
+
"print('missed', '+'.join([str(x) for x in missed_mut_resis]))\n",
|
142 |
+
"\n",
|
143 |
+
"print('hit', '+'.join([str(x) for x in eml4_alk_hit_resis]))"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"metadata": {},
|
149 |
+
"source": [
|
150 |
+
"# PYMOL code for EML4 ALK\n",
|
151 |
+
"# Set global cartoon transparency\n",
|
152 |
+
"# 0.5 for close up\n",
|
153 |
+
"color gray60\n",
|
154 |
+
"set cartoon_transparency, 0.6\n",
|
155 |
+
"\n",
|
156 |
+
"# Define a custom color\n",
|
157 |
+
"set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
|
158 |
+
"\n",
|
159 |
+
"# Select and color residues 567-843\n",
|
160 |
+
"sele alk, resi 567-843\n",
|
161 |
+
"color custom_blue, alk\n",
|
162 |
+
"\n",
|
163 |
+
"# Select missed mutation residues, color them orange, and make them fully opaque\n",
|
164 |
+
"sele missed_mut_resis, resi 603+707\n",
|
165 |
+
"color orange, missed_mut_resis\n",
|
166 |
+
"set cartoon_transparency, 0, missed_mut_resis\n",
|
167 |
+
"\n",
|
168 |
+
"# Select hit mutation residues, color them magenta, and make them fully opaque\n",
|
169 |
+
"sele hit_mut_resis, resi 607+622+625+631+647+649+653+654+657+661+696+720\n",
|
170 |
+
"color magenta, hit_mut_resis\n",
|
171 |
+
"set cartoon_transparency, 0, hit_mut_resis\n"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "code",
|
176 |
+
"execution_count": null,
|
177 |
+
"metadata": {},
|
178 |
+
"outputs": [],
|
179 |
+
"source": [
|
180 |
+
"def color_mutation_recovery(pdb_file, mutation_resis, hit_resis, other_domain_coloring=None):\n",
|
181 |
+
" # Create viewer\n",
|
182 |
+
" viewer = py3Dmol.view()\n",
|
183 |
+
"\n",
|
184 |
+
" # Load CIF file\n",
|
185 |
+
" with open(pdb_file, 'r') as f:\n",
|
186 |
+
" pdb_file = f.read()\n",
|
187 |
+
"\n",
|
188 |
+
" # Add structure\n",
|
189 |
+
" viewer.addModel(pdb_file, 'pdb')\n",
|
190 |
+
" \n",
|
191 |
+
" viewer.setStyle({})\n",
|
192 |
+
" viewer.setStyle({'cartoon': {'color': 'lightgrey','transparency': 0.7}})\n",
|
193 |
+
" \n",
|
194 |
+
" # Apply colors based on normalized disorder values\n",
|
195 |
+
" for i, value in enumerate(mutation_resis):\n",
|
196 |
+
" style = {'cartoon': {'transparency': 1}}\n",
|
197 |
+
" if value in hit_resis:\n",
|
198 |
+
" style['cartoon']['color'] = 'yellow'\n",
|
199 |
+
" else:\n",
|
200 |
+
" style['cartoon']['color'] = 'red'\n",
|
201 |
+
" # Apply combined style directly, avoiding any residual layering\n",
|
202 |
+
" viewer.setStyle({'resi': value}, style)\n",
|
203 |
+
" \n",
|
204 |
+
" # If you want to color specific domains by default, do it here\n",
|
205 |
+
" if other_domain_coloring is not None:\n",
|
206 |
+
" for name, (start, end, color) in other_domain_coloring.items():\n",
|
207 |
+
" resis = list(range(start, end+1))\n",
|
208 |
+
" resis = [x for x in resis if x not in mutation_resis]\n",
|
209 |
+
" viewer.setStyle({'resi': resis}, {'cartoon': {'color': color,'transparency': 0.7}})\n",
|
210 |
+
"\n",
|
211 |
+
" # Show viewer\n",
|
212 |
+
" viewer.zoomTo()\n",
|
213 |
+
" return viewer.show()\n",
|
214 |
+
"\n",
|
215 |
+
"color_mutation_recovery(path_to_eml4_alk_structure, eml4_alk_mutation_resis, eml4_alk_hit_resis, \n",
|
216 |
+
" other_domain_coloring=kinase_domain_coloring)"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": null,
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"# Color BCR::ABL structure\n",
|
226 |
+
"path_to_bcr_abl_structure = np.nan # not publicly available\n",
|
227 |
+
"bcr_abl_seq = np.nan # not publicly available \n",
|
228 |
+
"abl_kinase_domain_seq = np.nan # not publicly available \n",
|
229 |
+
"abl_kinase_domain_resis = [bcr_abl_seq.index(abl_kinase_domain_seq)+1, bcr_abl_seq.index(abl_kinase_domain_seq)+len(abl_kinase_domain_seq)]\n",
|
230 |
+
"print(abl_kinase_domain_resis)\n",
|
231 |
+
"print(abl_kinase_domain_seq)\n",
|
232 |
+
"print(len(bcr_abl_seq))"
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"cell_type": "code",
|
237 |
+
"execution_count": null,
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [],
|
240 |
+
"source": [
|
241 |
+
"bcr_abl_mutation_resis = bcr_abl_preds['Position'].tolist()\n",
|
242 |
+
"bcr_abl_hit_resis = bcr_abl_preds.loc[\n",
|
243 |
+
" bcr_abl_preds['Hit']\n",
|
244 |
+
"]['Position'].tolist()\n",
|
245 |
+
"kinase_domain_coloring = {\n",
|
246 |
+
" 'kinase': [abl_kinase_domain_resis[0], abl_kinase_domain_resis[1], '#6ea4da']\n",
|
247 |
+
"}\n",
|
248 |
+
"\n",
|
249 |
+
"bcr_abl_missed_mut_resis = [x for x in bcr_abl_mutation_resis if x not in bcr_abl_hit_resis]\n",
|
250 |
+
"print('missed residues', len(bcr_abl_missed_mut_resis), '+'.join([str(x) for x in bcr_abl_missed_mut_resis]))\n",
|
251 |
+
"print('hit residues', len(bcr_abl_hit_resis), '+'.join([str(x) for x in bcr_abl_hit_resis]))"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "markdown",
|
256 |
+
"metadata": {},
|
257 |
+
"source": [
|
258 |
+
"# PYMOL code for BCR ABL\n",
|
259 |
+
"# Set global cartoon transparency\n",
|
260 |
+
"# 0 for zoomed out, 0.5 for close up\n",
|
261 |
+
"color gray60\n",
|
262 |
+
"set cartoon_transparency, 0.6\n",
|
263 |
+
"\n",
|
264 |
+
"# Define a custom color\n",
|
265 |
+
"set_color custom_blue, [0x6e/255, 0xa4/255, 0xda/255]\n",
|
266 |
+
"\n",
|
267 |
+
"# Select and color residues 1085-1336\n",
|
268 |
+
"sele abl, resi 1085-1336\n",
|
269 |
+
"color custom_blue, abl\n",
|
270 |
+
"\n",
|
271 |
+
"# Select missed mutation residues, color them orange, and make them fully opaque\n",
|
272 |
+
"sele missed_mut_resis, resi 1087+1090+1093+1095+1116+1125+1128+1135+1140+1158+1192+1194+1218+1249+1273\n",
|
273 |
+
"color orange, missed_mut_resis\n",
|
274 |
+
"set cartoon_transparency, 0, missed_mut_resis\n",
|
275 |
+
"\n",
|
276 |
+
"# Select hit mutation residues, color them magenta, and make them fully opaque\n",
|
277 |
+
"sele hit_mut_resis, resi 1091+1096+1098+1132+1142+1154+1160+1198+1202+1222+1227+1230+1239\n",
|
278 |
+
"color magenta, hit_mut_resis\n",
|
279 |
+
"set cartoon_transparency, 0, hit_mut_resis"
|
280 |
+
]
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"colab": {
|
285 |
+
"collapsed_sections": [
|
286 |
+
"FJd6a9gdZNjG",
|
287 |
+
"zORkLJztZWp9",
|
288 |
+
"w25hagtZaV65",
|
289 |
+
"IbyqxlvAFUAK",
|
290 |
+
"0n5PSprbhLk7"
|
291 |
+
],
|
292 |
+
"machine_shape": "hm",
|
293 |
+
"provenance": []
|
294 |
+
},
|
295 |
+
"kernelspec": {
|
296 |
+
"display_name": "Python 3",
|
297 |
+
"name": "python3"
|
298 |
+
},
|
299 |
+
"language_info": {
|
300 |
+
"codemirror_mode": {
|
301 |
+
"name": "ipython",
|
302 |
+
"version": 3
|
303 |
+
},
|
304 |
+
"file_extension": ".py",
|
305 |
+
"mimetype": "text/x-python",
|
306 |
+
"name": "python",
|
307 |
+
"nbconvert_exporter": "python",
|
308 |
+
"pygments_lexer": "ipython3",
|
309 |
+
"version": "3.10.12"
|
310 |
+
}
|
311 |
+
},
|
312 |
+
"nbformat": 4,
|
313 |
+
"nbformat_minor": 0
|
314 |
+
}
|
fuson_plm/benchmarking/mutation_prediction/recovery/config.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
FUSON_PLM_CKPT = "FusOn-pLM" # Dictionary: key = run name, values = epochs, or string "FusOn-pLM"
|
2 |
+
|
3 |
+
CUDA_VISIBLE_DEVICES = "0"
|
fuson_plm/benchmarking/mutation_prediction/recovery/recover_public.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#### Recover mutations from literature. A benchmark
|
2 |
+
import fuson_plm.benchmarking.mutation_prediction.recovery.config as config
|
3 |
+
import os
|
4 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import transformers
|
9 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
10 |
+
import logging
|
11 |
+
import torch
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import seaborn as sns
|
14 |
+
import argparse
|
15 |
+
import os
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
|
19 |
+
from fuson_plm.benchmarking.embed import load_fuson_model
|
20 |
+
|
21 |
+
def check_env_variables():
|
22 |
+
log_update("\nChecking on environment variables...")
|
23 |
+
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
24 |
+
log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
|
25 |
+
for i in range(torch.cuda.device_count()):
|
26 |
+
log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
|
27 |
+
|
28 |
+
def get_top_k_aa_mutations(all_probabilities, sequence, i, top_k_mutations, k=10):
|
29 |
+
"""
|
30 |
+
Should only return top AA mutations
|
31 |
+
"""
|
32 |
+
all_probs = pd.DataFrame.from_dict(all_probabilities, orient='index').reset_index()
|
33 |
+
all_probs = all_probs.sort_values(by=0,ascending=False).reset_index(drop=True)
|
34 |
+
top_k_mutation = all_probs['index'].tolist()[0:k]
|
35 |
+
top_k_mutation = ",".join(top_k_mutation)
|
36 |
+
top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)
|
37 |
+
|
38 |
+
return top_k_mutations
|
39 |
+
|
40 |
+
def get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_k_mutations, k=3):
|
41 |
+
top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
|
42 |
+
top_k_mutation = []
|
43 |
+
for token in top_k_tokens:
|
44 |
+
replaced_text = tokenizer.decode([token])
|
45 |
+
top_k_mutation.append(replaced_text)
|
46 |
+
|
47 |
+
top_k_mutation = ",".join(top_k_mutation)
|
48 |
+
top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)
|
49 |
+
|
50 |
+
def predict_positionwise_mutations(model, tokenizer, device, sequence):
|
51 |
+
log_update("\t\tPredicting position-wise mutations...")
|
52 |
+
top_10_mutations = {}
|
53 |
+
decoded_full_sequence = ''
|
54 |
+
mut_count = 0
|
55 |
+
|
56 |
+
# Mask and unmask sequentially
|
57 |
+
for i in range(len(sequence)):
|
58 |
+
log_update(f"\t\t\t- pos {i+1}/{len(sequence)}")
|
59 |
+
all_probabilities = {} # stored probabilities of each AA at this position
|
60 |
+
|
61 |
+
# Mask JUST the current position
|
62 |
+
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
63 |
+
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True,max_length=2000)
|
64 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
65 |
+
|
66 |
+
# Forward pass
|
67 |
+
with torch.no_grad():
|
68 |
+
logits = model(**inputs).logits
|
69 |
+
|
70 |
+
# Find logits at masked positions (should just be 1!)
|
71 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
72 |
+
mask_token_logits = logits[0, mask_token_index, :]
|
73 |
+
mask_token_probs = F.softmax(mask_token_logits, dim=-1)
|
74 |
+
|
75 |
+
# Collect probabilities for natural AAs (token IDs 4-23 inclusive)
|
76 |
+
for token_idx in range(4, 23 + 1):
|
77 |
+
token = mask_token_probs[0, token_idx]
|
78 |
+
replaced_text = tokenizer.decode([token_idx])
|
79 |
+
all_probabilities[replaced_text] = token.item()
|
80 |
+
|
81 |
+
# Isolate top n mutations
|
82 |
+
#get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_10_mutations, k=10)
|
83 |
+
get_top_k_aa_mutations(all_probabilities, sequence, i, top_10_mutations, k=10)
|
84 |
+
|
85 |
+
# Building whole decoded sequence with top 1 token
|
86 |
+
top_1_tokens = torch.topk(mask_token_logits, 1, dim=1).indices[0].item()
|
87 |
+
new_residue = tokenizer.decode([top_1_tokens])
|
88 |
+
decoded_full_sequence += new_residue
|
89 |
+
|
90 |
+
# Check how many mutations in total
|
91 |
+
if sequence[i] != new_residue:
|
92 |
+
mut_count += 1
|
93 |
+
|
94 |
+
# Convert results into DataFrame
|
95 |
+
original_residues = []
|
96 |
+
top10_mutations = []
|
97 |
+
positions = []
|
98 |
+
all_logits = []
|
99 |
+
|
100 |
+
for (original_residue, position), (top10, probs) in top_10_mutations.items():
|
101 |
+
original_residues.append(original_residue)
|
102 |
+
top10_mutations.append(top10)
|
103 |
+
positions.append(position+1) # originally this line was "position" but it should be position + 1
|
104 |
+
all_logits.append(probs)
|
105 |
+
|
106 |
+
df = pd.DataFrame({
|
107 |
+
'Original Residue': original_residues,
|
108 |
+
'Position': positions,
|
109 |
+
'Top 10 Mutations': top10_mutations,
|
110 |
+
'All Probabilities': all_logits,
|
111 |
+
})
|
112 |
+
df['Top Mutation'] = df['Top 10 Mutations'].apply(lambda x: x.split(',')[0])
|
113 |
+
df['Top 3 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:3]))
|
114 |
+
df['Top 4 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:4]))
|
115 |
+
df['Top 5 Mutations'] = df['Top 10 Mutations'].apply(lambda x: ','.join(x.split(',')[0:5]))
|
116 |
+
|
117 |
+
return df, decoded_full_sequence, mut_count
|
118 |
+
|
119 |
+
def evaluate_literature_mut_performance(predicted_mutations_df, literature_mutations_df, decoded_full_sequence, mut_count, sequence="", focus_region_start=0, focus_region_end=0, offset=0):
|
120 |
+
"""
|
121 |
+
Given a dataframe of predicted mutations and literature mutations, see how well the predicted mutations did
|
122 |
+
"""
|
123 |
+
log_update("\t\tComparing predicted mutations to literature-provided mutations")
|
124 |
+
return_df = predicted_mutations_df.copy(deep=True)
|
125 |
+
return_df['Literature Mutation'] = [np.nan]*len(return_df)
|
126 |
+
return_df['Top 1 Hit'] = [np.nan]*len(return_df)
|
127 |
+
return_df['Top 3 Hit'] = [np.nan]*len(return_df)
|
128 |
+
return_df['Top 4 Hit'] = [np.nan]*len(return_df)
|
129 |
+
return_df['Top 5 Hit'] = [np.nan]*len(return_df)
|
130 |
+
return_df['Top 10 Hit'] = [np.nan]*len(return_df)
|
131 |
+
|
132 |
+
log_update(f"\tFormula: new position = {focus_region_start} + lit_position - {offset}")
|
133 |
+
# Iterate through the literature mutations rows
|
134 |
+
for i, row in literature_mutations_df.iterrows():
|
135 |
+
lit_position = row['Position']
|
136 |
+
lit_mutations = row['Mutation']
|
137 |
+
original_residue = row['Original Residue']
|
138 |
+
seq_position = focus_region_start + (lit_position - offset) # find position of the sequence
|
139 |
+
|
140 |
+
matching_row = return_df[return_df['Position'] == seq_position]
|
141 |
+
matching_row_index = matching_row.index
|
142 |
+
matching_residue = matching_row.iloc[0]['Original Residue']
|
143 |
+
match = original_residue==matching_residue
|
144 |
+
log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position}, Full sequence residue: {matching_residue}\n\t\tMatch: {match}")
|
145 |
+
|
146 |
+
# Iterate through the matching rows. We are at the right spot if we have the right original residue.
|
147 |
+
if match:
|
148 |
+
top_mutation = matching_row.iloc[0]['Top Mutation'] # get top 3 mutations
|
149 |
+
top_mutation = top_mutation.split(',')
|
150 |
+
print(top_mutation)
|
151 |
+
return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations # get desired mutation
|
152 |
+
# If we got any of the mutatios reported in the literature, hit!
|
153 |
+
if any(letter in lit_mutations for letter in top_mutation):
|
154 |
+
return_df.loc[matching_row_index, 'Top 1 Hit'] = True
|
155 |
+
else:
|
156 |
+
return_df.loc[matching_row_index, 'Top 1 Hit'] = False
|
157 |
+
|
158 |
+
for k in [3,4,5,10]:
|
159 |
+
top_k_mutations = matching_row.iloc[0][f'Top {k} Mutations'] # get top 3 mutations
|
160 |
+
top_k_mutations = top_k_mutations.split(",")
|
161 |
+
print(top_k_mutations)
|
162 |
+
return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations # get desired mutation
|
163 |
+
# If we got any of the mutatios reported in the literature, hit!
|
164 |
+
if any(letter in lit_mutations for letter in top_k_mutations):
|
165 |
+
return_df.loc[matching_row_index, f'Top {k} Hit'] = True
|
166 |
+
else:
|
167 |
+
return_df.loc[matching_row_index, f'Top {k} Hit'] = False
|
168 |
+
|
169 |
+
return return_df, (decoded_full_sequence, mut_count, (mut_count/len(sequence)) * 100)
|
170 |
+
|
171 |
+
def evaluate_eml4_alk(model, tokenizer, device, model_str):
|
172 |
+
alk_muts = pd.read_csv("alk_mutations.csv")
|
173 |
+
decoded_full_sequence, mut_count = None, None
|
174 |
+
|
175 |
+
EML4_ALK_SEQ = np.nan ## not publicly available
|
176 |
+
cons_domain_alk = np.nan # no publicly available
|
177 |
+
focus_region_start = EML4_ALK_SEQ.find(cons_domain_alk)
|
178 |
+
|
179 |
+
if os.path.isfile(f"eml4_alk_mutations/{model_str}/mutated_df.csv"):
|
180 |
+
log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from eml4_alk_mutations/{model_str}/mutated_df.csv")
|
181 |
+
mutated_df = pd.read_csv(f"eml4_alk_mutations/{model_str}/mutated_df.csv")
|
182 |
+
mutated_summary = pd.read_csv(f"eml4_alk_mutations/{model_str}/mutated_summary.csv")
|
183 |
+
decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
|
184 |
+
mut_count = mutated_summary['mut_count'][0]
|
185 |
+
else:
|
186 |
+
mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, EML4_ALK_SEQ)
|
187 |
+
mutated_summary = pd.DataFrame(data={'decoded_full_sequence':[decoded_full_sequence],'mut_count':[mut_count]})
|
188 |
+
mutated_df.to_csv(f"eml4_alk_mutations/{model_str}/mutated_df.csv",index=False)
|
189 |
+
mutated_summary.to_csv(f"eml4_alk_mutations/{model_str}/mutated_summary.csv",index=False)
|
190 |
+
|
191 |
+
lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(mutated_df, alk_muts, decoded_full_sequence, mut_count,
|
192 |
+
sequence=EML4_ALK_SEQ,
|
193 |
+
focus_region_start=focus_region_start,
|
194 |
+
focus_region_end = focus_region_start + len(cons_domain_alk),
|
195 |
+
offset=1115 # original: 1116
|
196 |
+
)
|
197 |
+
|
198 |
+
return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
|
199 |
+
|
200 |
+
def evaluate_bcr_abl(model, tokenizer, device, model_str):
|
201 |
+
abl_muts = pd.read_csv("abl_mutations.csv")
|
202 |
+
decoded_full_sequence, mut_count = None, None
|
203 |
+
|
204 |
+
BCR_ABL_SEQ = np.nan ## not publicly available
|
205 |
+
cons_domain_abl = np.nan ## not publicly available
|
206 |
+
focus_region_start = BCR_ABL_SEQ.find(cons_domain_abl)
|
207 |
+
|
208 |
+
if os.path.isfile(f"bcr_abl_mutations/{model_str}/mutated_df.csv"):
|
209 |
+
log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from bcr_abl_mutations/{model_str}/mutated_df.csv")
|
210 |
+
mutated_df = pd.read_csv(f"bcr_abl_mutations/{model_str}/mutated_df.csv")
|
211 |
+
mutated_summary = pd.read_csv(f"bcr_abl_mutations/{model_str}/mutated_summary.csv")
|
212 |
+
decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
|
213 |
+
mut_count = mutated_summary['mut_count'][0]
|
214 |
+
else:
|
215 |
+
mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, BCR_ABL_SEQ)
|
216 |
+
mutated_summary = pd.DataFrame(data={'decoded_full_sequence':[decoded_full_sequence],'mut_count':[mut_count]})
|
217 |
+
mutated_df.to_csv(f"bcr_abl_mutations/{model_str}/mutated_df.csv",index=False)
|
218 |
+
mutated_summary.to_csv(f"bcr_abl_mutations/{model_str}/mutated_summary.csv",index=False)
|
219 |
+
|
220 |
+
lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(mutated_df, abl_muts, decoded_full_sequence, mut_count,
|
221 |
+
sequence=BCR_ABL_SEQ,
|
222 |
+
focus_region_start=focus_region_start,
|
223 |
+
focus_region_end = focus_region_start + len(cons_domain_abl),
|
224 |
+
offset=241 # original: 242
|
225 |
+
)
|
226 |
+
|
227 |
+
return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
|
228 |
+
|
229 |
+
def summarize_individual_performance(performance_df, path_to_lit_df):
|
230 |
+
"""
|
231 |
+
performance_df = dataframe with stats on performance
|
232 |
+
path_to_lit_df = original dataframe
|
233 |
+
"""
|
234 |
+
# Load original df
|
235 |
+
lit_muts = pd.read_csv(path_to_lit_df)
|
236 |
+
|
237 |
+
# Mutated Sequence,Original Residue,Position,Top 3 Mutations,Literature Mutation,Hit,All Probabilities
|
238 |
+
mut_rows = performance_df.loc[performance_df['Literature Mutation'].notna()].reset_index(drop=True)
|
239 |
+
mut_rows = mut_rows[['Original Residue','Position','Literature Mutation',
|
240 |
+
'Top Mutation','Top 1 Hit',
|
241 |
+
'Top 3 Mutations','Top 3 Hit',
|
242 |
+
'Top 4 Mutations','Top 4 Hit',
|
243 |
+
'Top 5 Mutations','Top 5 Hit',
|
244 |
+
'Top 10 Mutations','Top 10 Hit'
|
245 |
+
]]
|
246 |
+
|
247 |
+
mut_rows_str = mut_rows.to_string(index=False)
|
248 |
+
mut_rows_str = "\t\t" + mut_rows_str.replace("\n","\n\t\t")
|
249 |
+
log_update(f"\tPerformance on all mutated positions shown below:\n{mut_rows_str}")
|
250 |
+
|
251 |
+
# Summarize: total hits, percentage of hits
|
252 |
+
total_original_muts = len(lit_muts)
|
253 |
+
for k in [1,3,4,5,10]:
|
254 |
+
total_hits = len(mut_rows.loc[mut_rows[f'Top {k} Hit']==True])
|
255 |
+
total_misses = len(mut_rows.loc[mut_rows[f'Top {k} Hit']==False])
|
256 |
+
total_potential_muts = total_hits+total_misses
|
257 |
+
hit_pcnt = round(100*total_hits/total_potential_muts, 2)
|
258 |
+
miss_pcnt = round(100*total_misses/total_potential_muts, 2)
|
259 |
+
|
260 |
+
log_update(f"\tTotal positions tested / total positions mutated in literature: {total_potential_muts}/{total_original_muts}")
|
261 |
+
log_update(f"\n\t\tTop {k} hit performance:\n\t\t\tHits:{total_hits} ({hit_pcnt}%)\n\t\t\tMisses:{total_misses} ({miss_pcnt}%)")
|
262 |
+
|
263 |
+
def main():
|
264 |
+
os.makedirs('results',exist_ok=True)
|
265 |
+
output_dir = f'results/{get_local_time()}'
|
266 |
+
os.makedirs(output_dir,exist_ok=True)
|
267 |
+
os.makedirs("bcr_abl_mutations",exist_ok=True)
|
268 |
+
os.makedirs("eml4_alk_mutations",exist_ok=True)
|
269 |
+
with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
|
270 |
+
print_configpy(config)
|
271 |
+
|
272 |
+
# Make sure environment variables are set correctly
|
273 |
+
check_env_variables()
|
274 |
+
|
275 |
+
# Get device
|
276 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
277 |
+
log_update(f"Using device: {device}")
|
278 |
+
|
279 |
+
# Load fuson
|
280 |
+
fuson_ckpt_path = config.FUSON_PLM_CKPT
|
281 |
+
if fuson_ckpt_path=="FusOn-pLM":
|
282 |
+
fuson_ckpt_path="../../../.."
|
283 |
+
model_name = "fuson_plm"
|
284 |
+
model_epoch = "best"
|
285 |
+
model_str = f"fuson_plm/best"
|
286 |
+
else:
|
287 |
+
model_name = list(fuson_ckpt_path.keys())[0]
|
288 |
+
epoch = list(fuson_ckpt_path.values())[0]
|
289 |
+
fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
|
290 |
+
model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
|
291 |
+
model_epoch = model_epoch.split('checkpoint_')[-1]
|
292 |
+
model_str = f"{model_name}/{model_epoch}"
|
293 |
+
|
294 |
+
log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
|
295 |
+
fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
|
296 |
+
fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
|
297 |
+
fuson_model.to(device)
|
298 |
+
fuson_model.eval()
|
299 |
+
|
300 |
+
|
301 |
+
# Evaluate BCR::ABL performance with FusOn
|
302 |
+
os.makedirs(f"bcr_abl_mutations/{model_name}",exist_ok=True)
|
303 |
+
os.makedirs(f"bcr_abl_mutations/{model_name}/{model_epoch}",exist_ok=True)
|
304 |
+
log_update("\tEvaluating performance on BCR::ABL mutation prediction with FusOn")
|
305 |
+
abl_lit_performance_fuson, (mut_seq, mut_count, mut_pcnt) = evaluate_bcr_abl(fuson_model, fuson_tokenizer, device, model_str)
|
306 |
+
abl_lit_performance_fuson.to_csv(f'{output_dir}/BCR_ABL_mutation_recovery_fuson.csv', index = False)
|
307 |
+
with open(f'{output_dir}/BCR_ABL_mutation_recovery_fuson_summary.txt', 'w') as f:
|
308 |
+
f.write(mut_seq)
|
309 |
+
f.write(f'number of mutations: {mut_count}')
|
310 |
+
f.write(f'percentage of seq mutated: {mut_pcnt}')
|
311 |
+
|
312 |
+
# Evaluate EML4::ALK performance with Fuson
|
313 |
+
os.makedirs(f"eml4_alk_mutations/{model_name}",exist_ok=True)
|
314 |
+
os.makedirs(f"eml4_alk_mutations/{model_name}/{model_epoch}",exist_ok=True)
|
315 |
+
log_update("\tEvaluating performance on EML4::ALK mutation prediction with FusOn")
|
316 |
+
alk_lit_performance_fuson, (mut_seq, mut_count, mut_pcnt) = evaluate_eml4_alk(fuson_model, fuson_tokenizer, device, model_str)
|
317 |
+
alk_lit_performance_fuson.to_csv(f'{output_dir}/EML4_ALK_mutation_recovery_fuson.csv', index = False)
|
318 |
+
with open(f'{output_dir}/EML4_ALK_mutation_recovery_fuson_summary.txt', 'w') as f:
|
319 |
+
f.write(mut_seq)
|
320 |
+
f.write(f'number of mutations: {mut_count}')
|
321 |
+
f.write(f'percentage of seq mutated: {mut_pcnt}')
|
322 |
+
|
323 |
+
### Summarize
|
324 |
+
log_update("\nSummarizing FusOn-pLM performance on BCR::ABL")
|
325 |
+
summarize_individual_performance(abl_lit_performance_fuson, "abl_mutations.csv")
|
326 |
+
log_update("\nSummarizing FusOn-pLM performance on EML4::ALK")
|
327 |
+
summarize_individual_performance(alk_lit_performance_fuson, "alk_mutations.csv")
|
328 |
+
|
329 |
+
if __name__ == "__main__":
|
330 |
+
main()
|
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/BCR_ABL_mutation_recovery_fuson_mutated_pns_only.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:779bac32bf18f7efdf843120274bdcd0ee2dd1a940655b7f1af1bc0d640bda1a
|
3 |
+
size 18267
|
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/EML4_ALK_mutation_recovery_fuson_mutated_pns_only.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2fdc821e3e303aae226567cea331437eed8af96db43ebbd7487b7a48767de51e
|
3 |
+
size 9375
|
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - EML4 ALK Mutations.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bbc2b69ac1e403b060cda54b753c0e8cd9c0bbac1c64cdf38128aaa23081e12d
|
3 |
+
size 709
|
fuson_plm/benchmarking/mutation_prediction/recovery/results/final_public/Supplementary Tables - BCR ABL Mutations.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a100f40575539ae5009e0ee87aac73b4b7701ee108cacd417585a8e02155e5c1
|
3 |
+
size 1380
|
fuson_plm/benchmarking/puncta/README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
## Puncta Prediction Benchmark
|
2 |
|
3 |
-
This folder contains all the data and code needed to perform the **puncta prediction benchmark** (Figure 3).
|
4 |
|
5 |
### From raw data to train/test splits
|
6 |
To train the puncta predictors, we processed raw data from FOdb [(Tripathi et al. 2023)](https://doi.org/10.1038/s41467-023-41655-2) Supplementary dataset 4 (`fuson_plm/data/raw_data/FOdb_puncta.csv`) and Supplementary dataset 5 (`fuson_plm/data/raw_data/FODb_SD5.csv`) using the file `clean.py` in the `puncta` directory.
|
|
|
1 |
## Puncta Prediction Benchmark
|
2 |
|
3 |
+
This folder contains all the data and code needed to train FusOn-pLM-Puncta models and perform the **puncta prediction benchmark** (Figure 3).
|
4 |
|
5 |
### From raw data to train/test splits
|
6 |
To train the puncta predictors, we processed raw data from FOdb [(Tripathi et al. 2023)](https://doi.org/10.1038/s41467-023-41655-2) Supplementary dataset 4 (`fuson_plm/data/raw_data/FOdb_puncta.csv`) and Supplementary dataset 5 (`fuson_plm/data/raw_data/FODb_SD5.csv`) using the file `clean.py` in the `puncta` directory.
|