dependencies and embedding_exploration benchmark
- fuson_plm/README.md +588 -0
- fuson_plm/benchmarking/README.md +11 -0
- fuson_plm/benchmarking/__init__.py +0 -0
- fuson_plm/benchmarking/embed.py +296 -0
- fuson_plm/benchmarking/embedding_exploration/README.md +58 -0
- fuson_plm/benchmarking/embedding_exploration/__init__.py +0 -0
- fuson_plm/benchmarking/embedding_exploration/config.py +10 -0
- fuson_plm/benchmarking/embedding_exploration/data/salokas_2020_tableS3.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/data/tf_and_kinase_fusions.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/data/top_genes.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/plot.py +496 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_visualization.png +0 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/README.md +1 -1
- fuson_plm/benchmarking/puncta/train.py +1 -1
- fuson_plm/benchmarking/xgboost_predictor.py +65 -0
fuson_plm/README.md
ADDED
@@ -0,0 +1,588 @@
# Dependencies

Here we provide the package versions needed to run FusOn-pLM code. Docker containers were used for this project; below is a pip list of what is inside each container, as well as the images the containers were built from.

## pip installs

The following dependencies were used for all training and benchmarking except for the `puncta` benchmarks.
Note that after cloning the repository, you will need to run `pip install -e .` outside the `fuson_plm` directory (i.e., in the repository root) to install the `fuson_plm` package.
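A quick way to confirm the editable install worked is to import the package and check where it resolves from (a minimal sketch, not part of the repo; the exact path depends on where you cloned the repository):

```
# Sanity check after running `pip install -e .` in the cloned repository root
import fuson_plm
print(fuson_plm.__file__)  # should resolve to a path inside your FusOn-pLM clone, e.g. .../FusOn-pLM/fuson_plm/__init__.py
```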
10 |
+
Package Version Editable project location
|
11 |
+
------------------------- -------------------- -------------------------
|
12 |
+
absl-py 1.4.0
|
13 |
+
aiohttp 3.8.4
|
14 |
+
aiosignal 1.3.1
|
15 |
+
apex 0.1
|
16 |
+
argon2-cffi 21.3.0
|
17 |
+
argon2-cffi-bindings 21.2.0
|
18 |
+
asttokens 2.2.1
|
19 |
+
astunparse 1.6.3
|
20 |
+
async-timeout 4.0.2
|
21 |
+
attrs 23.1.0
|
22 |
+
audioread 3.0.0
|
23 |
+
backcall 0.2.0
|
24 |
+
beautifulsoup4 4.12.2
|
25 |
+
bio 1.7.1
|
26 |
+
biopython 1.84
|
27 |
+
biothings-client 0.3.1
|
28 |
+
bleach 6.0.0
|
29 |
+
blis 0.7.10
|
30 |
+
cachetools 5.3.1
|
31 |
+
catalogue 2.0.9
|
32 |
+
certifi 2023.7.22
|
33 |
+
cffi 1.15.1
|
34 |
+
charset-normalizer 3.2.0
|
35 |
+
click 8.1.5
|
36 |
+
cloudpickle 2.2.1
|
37 |
+
cmake 3.27.1
|
38 |
+
comm 0.1.4
|
39 |
+
confection 0.1.1
|
40 |
+
contourpy 1.1.0
|
41 |
+
cubinlinker 0.3.0+2.g7c3675e
|
42 |
+
cuda-python 12.1.0rc5+1.g994d8d0
|
43 |
+
cudf 23.6.0
|
44 |
+
cugraph 23.6.0
|
45 |
+
cugraph-dgl 23.6.0
|
46 |
+
cugraph-service-client 23.6.0
|
47 |
+
cugraph-service-server 23.6.0
|
48 |
+
cuml 23.6.0
|
49 |
+
cupy-cuda12x 12.1.0
|
50 |
+
cycler 0.11.0
|
51 |
+
cymem 2.0.7
|
52 |
+
Cython 3.0.0
|
53 |
+
dask 2023.3.2
|
54 |
+
dask-cuda 23.6.0
|
55 |
+
dask-cudf 23.6.0
|
56 |
+
debugpy 1.6.7
|
57 |
+
decorator 5.1.1
|
58 |
+
defusedxml 0.7.1
|
59 |
+
distributed 2023.3.2.1
|
60 |
+
dm-tree 0.1.8
|
61 |
+
docker-pycreds 0.4.0
|
62 |
+
einops 0.6.1
|
63 |
+
exceptiongroup 1.1.2
|
64 |
+
execnet 2.0.2
|
65 |
+
executing 1.2.0
|
66 |
+
expecttest 0.1.3
|
67 |
+
fair-esm 2.0.0
|
68 |
+
fastjsonschema 2.18.0
|
69 |
+
fastrlock 0.8.1
|
70 |
+
filelock 3.12.2
|
71 |
+
flash-attn 2.0.4
|
72 |
+
fonttools 4.42.0
|
73 |
+
frozenlist 1.4.0
|
74 |
+
fsspec 2023.6.0
|
75 |
+
fuson-plm 1.0 /workspace/FusOn-pLM
|
76 |
+
gast 0.5.4
|
77 |
+
gdown 5.2.0
|
78 |
+
gitdb 4.0.11
|
79 |
+
GitPython 3.1.43
|
80 |
+
google-auth 2.22.0
|
81 |
+
google-auth-oauthlib 0.4.6
|
82 |
+
gprofiler-official 1.0.0
|
83 |
+
graphsurgeon 0.4.6
|
84 |
+
grpcio 1.56.2
|
85 |
+
huggingface-hub 0.25.2
|
86 |
+
hypothesis 5.35.1
|
87 |
+
idna 3.4
|
88 |
+
importlib-metadata 6.8.0
|
89 |
+
iniconfig 2.0.0
|
90 |
+
intel-openmp 2021.4.0
|
91 |
+
ipykernel 6.25.0
|
92 |
+
ipython 8.14.0
|
93 |
+
ipython-genutils 0.2.0
|
94 |
+
jedi 0.19.0
|
95 |
+
Jinja2 3.1.2
|
96 |
+
joblib 1.3.1
|
97 |
+
json5 0.9.14
|
98 |
+
jsonschema 4.18.6
|
99 |
+
jsonschema-specifications 2023.7.1
|
100 |
+
jupyter_client 8.3.0
|
101 |
+
jupyter_core 5.3.1
|
102 |
+
jupyter-tensorboard 0.2.0
|
103 |
+
jupyterlab 2.3.2
|
104 |
+
jupyterlab-pygments 0.2.2
|
105 |
+
jupyterlab-server 1.2.0
|
106 |
+
jupytext 1.15.0
|
107 |
+
kiwisolver 1.4.4
|
108 |
+
langcodes 3.3.0
|
109 |
+
librosa 0.9.2
|
110 |
+
lightning-utilities 0.11.8
|
111 |
+
llvmlite 0.40.1
|
112 |
+
locket 1.0.0
|
113 |
+
Markdown 3.4.4
|
114 |
+
markdown-it-py 3.0.0
|
115 |
+
MarkupSafe 2.1.3
|
116 |
+
matplotlib 3.7.2
|
117 |
+
matplotlib-inline 0.1.6
|
118 |
+
mdit-py-plugins 0.4.0
|
119 |
+
mdurl 0.1.2
|
120 |
+
mistune 3.0.1
|
121 |
+
mkl 2021.1.1
|
122 |
+
mkl-devel 2021.1.1
|
123 |
+
mkl-include 2021.1.1
|
124 |
+
mock 5.1.0
|
125 |
+
mpmath 1.3.0
|
126 |
+
msgpack 1.0.5
|
127 |
+
multidict 6.0.4
|
128 |
+
murmurhash 1.0.9
|
129 |
+
mygene 3.2.2
|
130 |
+
nbclient 0.8.0
|
131 |
+
nbconvert 7.7.3
|
132 |
+
nbformat 5.9.2
|
133 |
+
nest-asyncio 1.5.7
|
134 |
+
networkx 2.6.3
|
135 |
+
ninja 1.11.1
|
136 |
+
notebook 6.4.10
|
137 |
+
numba 0.57.1+1.gc785c8f1f
|
138 |
+
numpy 1.22.2
|
139 |
+
nvidia-cublas-cu12 12.4.5.8
|
140 |
+
nvidia-cuda-cupti-cu12 12.4.127
|
141 |
+
nvidia-cuda-nvrtc-cu12 12.4.127
|
142 |
+
nvidia-cuda-runtime-cu12 12.4.127
|
143 |
+
nvidia-cudnn-cu12 9.1.0.70
|
144 |
+
nvidia-cufft-cu12 11.2.1.3
|
145 |
+
nvidia-curand-cu12 10.3.5.147
|
146 |
+
nvidia-cusolver-cu12 11.6.1.9
|
147 |
+
nvidia-cusparse-cu12 12.3.1.170
|
148 |
+
nvidia-dali-cuda120 1.28.0
|
149 |
+
nvidia-nccl-cu12 2.21.5
|
150 |
+
nvidia-nvjitlink-cu12 12.4.127
|
151 |
+
nvidia-nvtx-cu12 12.4.127
|
152 |
+
nvidia-pyindex 1.0.9
|
153 |
+
nvtx 0.2.5
|
154 |
+
oauthlib 3.2.2
|
155 |
+
onnx 1.14.0
|
156 |
+
opencv 4.7.0
|
157 |
+
packaging 23.1
|
158 |
+
pandas 1.5.2
|
159 |
+
pandocfilters 1.5.0
|
160 |
+
parso 0.8.3
|
161 |
+
partd 1.4.0
|
162 |
+
pathy 0.10.2
|
163 |
+
pexpect 4.8.0
|
164 |
+
pickleshare 0.7.5
|
165 |
+
Pillow 9.2.0
|
166 |
+
pip 23.2.1
|
167 |
+
platformdirs 3.10.0
|
168 |
+
pluggy 1.2.0
|
169 |
+
ply 3.11
|
170 |
+
polygraphy 0.47.1
|
171 |
+
pooch 1.7.0
|
172 |
+
preshed 3.0.8
|
173 |
+
prettytable 3.8.0
|
174 |
+
prometheus-client 0.17.1
|
175 |
+
prompt-toolkit 3.0.39
|
176 |
+
protobuf 4.21.12
|
177 |
+
psutil 5.9.4
|
178 |
+
ptxcompiler 0.8.1+1.g4a94326
|
179 |
+
ptyprocess 0.7.0
|
180 |
+
pure-eval 0.2.2
|
181 |
+
py3Dmol 2.4.0
|
182 |
+
pyarrow 11.0.0
|
183 |
+
pyasn1 0.5.0
|
184 |
+
pyasn1-modules 0.3.0
|
185 |
+
pybind11 2.11.1
|
186 |
+
pycocotools 2.0+nv0.7.3
|
187 |
+
pycparser 2.21
|
188 |
+
pydantic 1.10.12
|
189 |
+
Pygments 2.16.1
|
190 |
+
pylibcugraph 23.6.0
|
191 |
+
pylibcugraphops 23.6.0
|
192 |
+
pylibraft 23.6.0
|
193 |
+
pynndescent 0.5.13
|
194 |
+
pynvml 11.4.1
|
195 |
+
pyparsing 3.0.9
|
196 |
+
PySocks 1.7.1
|
197 |
+
pytest 7.4.0
|
198 |
+
pytest-flakefinder 1.1.0
|
199 |
+
pytest-rerunfailures 12.0
|
200 |
+
pytest-shard 0.1.2
|
201 |
+
pytest-xdist 3.3.1
|
202 |
+
python-dateutil 2.8.2
|
203 |
+
python-hostlist 1.23.0
|
204 |
+
pytorch-lightning 2.4.0
|
205 |
+
pytorch-quantization 2.1.2
|
206 |
+
pytz 2023.3
|
207 |
+
PyYAML 6.0.1
|
208 |
+
pyzmq 25.1.0
|
209 |
+
raft-dask 23.6.0
|
210 |
+
referencing 0.30.2
|
211 |
+
regex 2023.6.3
|
212 |
+
requests 2.31.0
|
213 |
+
requests-oauthlib 1.3.1
|
214 |
+
resampy 0.4.2
|
215 |
+
rmm 23.6.0
|
216 |
+
rpds-py 0.9.2
|
217 |
+
rsa 4.9
|
218 |
+
safetensors 0.4.5
|
219 |
+
scikit-learn 1.2.0
|
220 |
+
scipy 1.11.1
|
221 |
+
seaborn 0.13.2
|
222 |
+
Send2Trash 1.8.2
|
223 |
+
sentencepiece 0.2.0
|
224 |
+
sentry-sdk 2.16.0
|
225 |
+
setproctitle 1.3.3
|
226 |
+
setuptools 68.0.0
|
227 |
+
six 1.16.0
|
228 |
+
smart-open 6.3.0
|
229 |
+
smmap 5.0.1
|
230 |
+
sortedcontainers 2.4.0
|
231 |
+
soundfile 0.12.1
|
232 |
+
soupsieve 2.4.1
|
233 |
+
spacy 3.6.0
|
234 |
+
spacy-legacy 3.0.12
|
235 |
+
spacy-loggers 1.0.4
|
236 |
+
sphinx-glpi-theme 0.3
|
237 |
+
srsly 2.4.7
|
238 |
+
stack-data 0.6.2
|
239 |
+
sympy 1.13.1
|
240 |
+
tabulate 0.9.0
|
241 |
+
tbb 2021.10.0
|
242 |
+
tblib 2.0.0
|
243 |
+
tensorboard 2.9.0
|
244 |
+
tensorboard-data-server 0.6.1
|
245 |
+
tensorboard-plugin-wit 1.8.1
|
246 |
+
tensorrt 8.6.1
|
247 |
+
terminado 0.17.1
|
248 |
+
thinc 8.1.10
|
249 |
+
threadpoolctl 3.2.0
|
250 |
+
thriftpy2 0.4.16
|
251 |
+
tinycss2 1.2.1
|
252 |
+
tokenizers 0.20.1
|
253 |
+
toml 0.10.2
|
254 |
+
tomli 2.0.1
|
255 |
+
toolz 0.12.0
|
256 |
+
torch 2.5.0
|
257 |
+
torch-tensorrt 2.0.0.dev0
|
258 |
+
torchdata 0.7.0a0
|
259 |
+
torchmetrics 1.5.0
|
260 |
+
torchtext 0.16.0a0
|
261 |
+
torchvision 0.16.0a0
|
262 |
+
tornado 6.3.2
|
263 |
+
tqdm 4.65.0
|
264 |
+
traitlets 5.9.0
|
265 |
+
transformer-engine 0.11.0+3f01b4f
|
266 |
+
transformers 4.45.2
|
267 |
+
treelite 3.2.0
|
268 |
+
treelite-runtime 3.2.0
|
269 |
+
triton 3.1.0
|
270 |
+
typer 0.9.0
|
271 |
+
types-dataclasses 0.6.6
|
272 |
+
typing_extensions 4.12.2
|
273 |
+
ucx-py 0.32.0
|
274 |
+
uff 0.6.9
|
275 |
+
umap-learn 0.5.6
|
276 |
+
urllib3 1.26.16
|
277 |
+
wandb 0.18.3
|
278 |
+
wasabi 1.1.2
|
279 |
+
wcwidth 0.2.6
|
280 |
+
webencodings 0.5.1
|
281 |
+
Werkzeug 2.3.6
|
282 |
+
wheel 0.41.1
|
283 |
+
xdoctest 1.0.2
|
284 |
+
xgboost 1.7.5
|
285 |
+
yarl 1.9.2
|
286 |
+
zict 3.0.0
|
287 |
+
zipp 3.16.2
|
288 |
+
|
289 |
+
The following packages and versions were used for the `puncta` benchmarks. A different environment was required to run ProtT5.
|
290 |
+
|
291 |
+
Package Version Editable project location
|
292 |
+
------------------------- -------------------------- -------------------------
|
293 |
+
absl-py 2.1.0
|
294 |
+
aiohttp 3.9.3
|
295 |
+
aiosignal 1.3.1
|
296 |
+
annotated-types 0.6.0
|
297 |
+
anyio 4.8.0
|
298 |
+
apex 0.1
|
299 |
+
argon2-cffi 23.1.0
|
300 |
+
argon2-cffi-bindings 21.2.0
|
301 |
+
asttokens 2.4.1
|
302 |
+
astunparse 1.6.3
|
303 |
+
async-timeout 4.0.3
|
304 |
+
attrs 23.2.0
|
305 |
+
audioread 3.0.1
|
306 |
+
beautifulsoup4 4.12.3
|
307 |
+
bio 1.7.1
|
308 |
+
biopython 1.85
|
309 |
+
biothings_client 0.4.1
|
310 |
+
bleach 6.1.0
|
311 |
+
blis 0.7.11
|
312 |
+
cachetools 5.3.3
|
313 |
+
catalogue 2.0.10
|
314 |
+
certifi 2024.2.2
|
315 |
+
cffi 1.16.0
|
316 |
+
charset-normalizer 3.3.2
|
317 |
+
click 8.1.7
|
318 |
+
cloudpathlib 0.16.0
|
319 |
+
cloudpickle 3.0.0
|
320 |
+
cmake 3.29.0.1
|
321 |
+
comm 0.2.2
|
322 |
+
confection 0.1.4
|
323 |
+
contourpy 1.2.1
|
324 |
+
cuda-python 12.4.0rc7+3.ge75c8a9.dirty
|
325 |
+
cudf 24.2.0
|
326 |
+
cudnn 1.1.2
|
327 |
+
cugraph 24.2.0
|
328 |
+
cugraph-dgl 24.2.0
|
329 |
+
cugraph-service-client 24.2.0
|
330 |
+
cugraph-service-server 24.2.0
|
331 |
+
cuml 24.2.0
|
332 |
+
cupy-cuda12x 13.0.0
|
333 |
+
cycler 0.12.1
|
334 |
+
cymem 2.0.8
|
335 |
+
Cython 3.0.10
|
336 |
+
dask 2024.1.1
|
337 |
+
dask-cuda 24.2.0
|
338 |
+
dask-cudf 24.2.0
|
339 |
+
debugpy 1.8.1
|
340 |
+
decorator 5.1.1
|
341 |
+
defusedxml 0.7.1
|
342 |
+
distributed 2024.1.1
|
343 |
+
dm-tree 0.1.8
|
344 |
+
docker-pycreds 0.4.0
|
345 |
+
einops 0.7.0
|
346 |
+
exceptiongroup 1.2.0
|
347 |
+
execnet 2.0.2
|
348 |
+
executing 2.0.1
|
349 |
+
expecttest 0.1.3
|
350 |
+
fair-esm 2.0.0
|
351 |
+
fastjsonschema 2.19.1
|
352 |
+
fastrlock 0.8.2
|
353 |
+
filelock 3.13.3
|
354 |
+
flash-attn 2.4.2
|
355 |
+
fonttools 4.51.0
|
356 |
+
frozenlist 1.4.1
|
357 |
+
fsspec 2024.2.0
|
358 |
+
fuson-plm 1.0 /workspace/FusOn-pLM
|
359 |
+
gast 0.5.4
|
360 |
+
gdown 5.2.0
|
361 |
+
gitdb 4.0.12
|
362 |
+
GitPython 3.1.44
|
363 |
+
google-auth 2.29.0
|
364 |
+
google-auth-oauthlib 0.4.6
|
365 |
+
gprofiler-official 1.0.0
|
366 |
+
graphsurgeon 0.4.6
|
367 |
+
grpcio 1.62.1
|
368 |
+
h11 0.14.0
|
369 |
+
httpcore 1.0.7
|
370 |
+
httpx 0.28.1
|
371 |
+
huggingface-hub 0.27.1
|
372 |
+
hypothesis 5.35.1
|
373 |
+
idna 3.6
|
374 |
+
igraph 0.11.4
|
375 |
+
importlib_metadata 7.0.2
|
376 |
+
iniconfig 2.0.0
|
377 |
+
intel-openmp 2021.4.0
|
378 |
+
ipykernel 6.29.4
|
379 |
+
ipython 8.21.0
|
380 |
+
ipython-genutils 0.2.0
|
381 |
+
jedi 0.19.1
|
382 |
+
Jinja2 3.1.3
|
383 |
+
joblib 1.3.2
|
384 |
+
json5 0.9.24
|
385 |
+
jsonschema 4.21.1
|
386 |
+
jsonschema-specifications 2023.12.1
|
387 |
+
jupyter_client 8.6.1
|
388 |
+
jupyter_core 5.7.2
|
389 |
+
jupyter-tensorboard 0.2.0
|
390 |
+
jupyterlab 2.3.2
|
391 |
+
jupyterlab_pygments 0.3.0
|
392 |
+
jupyterlab-server 1.2.0
|
393 |
+
jupytext 1.16.1
|
394 |
+
kiwisolver 1.4.5
|
395 |
+
langcodes 3.3.0
|
396 |
+
lark 1.1.9
|
397 |
+
lazy_loader 0.4
|
398 |
+
librosa 0.10.1
|
399 |
+
lightning-thunder 0.1.0
|
400 |
+
lightning-utilities 0.11.2
|
401 |
+
llvmlite 0.42.0
|
402 |
+
locket 1.0.0
|
403 |
+
looseversion 1.3.0
|
404 |
+
Markdown 3.6
|
405 |
+
markdown-it-py 3.0.0
|
406 |
+
MarkupSafe 2.1.5
|
407 |
+
matplotlib 3.8.4
|
408 |
+
matplotlib-inline 0.1.6
|
409 |
+
mdit-py-plugins 0.4.0
|
410 |
+
mdurl 0.1.2
|
411 |
+
mistune 3.0.2
|
412 |
+
mkl 2021.1.1
|
413 |
+
mkl-devel 2021.1.1
|
414 |
+
mkl-include 2021.1.1
|
415 |
+
mock 5.1.0
|
416 |
+
mpmath 1.3.0
|
417 |
+
msgpack 1.0.8
|
418 |
+
multidict 6.0.5
|
419 |
+
murmurhash 1.0.10
|
420 |
+
mygene 3.2.2
|
421 |
+
nbclient 0.10.0
|
422 |
+
nbconvert 7.16.3
|
423 |
+
nbformat 5.10.4
|
424 |
+
nest-asyncio 1.6.0
|
425 |
+
networkx 2.6.3
|
426 |
+
ninja 1.11.1.1
|
427 |
+
notebook 6.4.10
|
428 |
+
numba 0.59.0+1.g20ae2b56c
|
429 |
+
numpy 1.24.4
|
430 |
+
nvfuser 0.1.6a0+a684e2a
|
431 |
+
nvidia-dali-cuda120 1.36.0
|
432 |
+
nvidia-nvimgcodec-cu12 0.2.0.7
|
433 |
+
nvidia-pyindex 1.0.9
|
434 |
+
nvtx 0.2.5
|
435 |
+
oauthlib 3.2.2
|
436 |
+
onnx 1.16.0
|
437 |
+
opencv 4.7.0
|
438 |
+
opt-einsum 3.3.0
|
439 |
+
optree 0.11.0
|
440 |
+
packaging 23.2
|
441 |
+
pandas 1.5.3
|
442 |
+
pandocfilters 1.5.1
|
443 |
+
parso 0.8.4
|
444 |
+
partd 1.4.1
|
445 |
+
pexpect 4.9.0
|
446 |
+
pillow 10.2.0
|
447 |
+
pip 24.0
|
448 |
+
platformdirs 4.2.0
|
449 |
+
pluggy 1.4.0
|
450 |
+
ply 3.11
|
451 |
+
polygraphy 0.49.8
|
452 |
+
pooch 1.8.1
|
453 |
+
preshed 3.0.9
|
454 |
+
prettytable 3.10.0
|
455 |
+
prometheus_client 0.20.0
|
456 |
+
prompt-toolkit 3.0.43
|
457 |
+
protobuf 4.24.4
|
458 |
+
psutil 5.9.4
|
459 |
+
ptyprocess 0.7.0
|
460 |
+
pure-eval 0.2.2
|
461 |
+
py3Dmol 2.4.2
|
462 |
+
pyarrow 14.0.1
|
463 |
+
pyasn1 0.6.0
|
464 |
+
pyasn1_modules 0.4.0
|
465 |
+
pybind11 2.12.0
|
466 |
+
pybind11_global 2.12.0
|
467 |
+
pycocotools 2.0+nv0.8.0
|
468 |
+
pycparser 2.22
|
469 |
+
pydantic 2.6.4
|
470 |
+
pydantic_core 2.16.3
|
471 |
+
Pygments 2.17.2
|
472 |
+
pylibcugraph 24.2.0
|
473 |
+
pylibcugraphops 24.2.0
|
474 |
+
pylibraft 24.2.0
|
475 |
+
pynndescent 0.5.13
|
476 |
+
pynvjitlink 0.1.13
|
477 |
+
pynvml 11.4.1
|
478 |
+
pyparsing 3.1.2
|
479 |
+
PySocks 1.7.1
|
480 |
+
pytest 8.1.1
|
481 |
+
pytest-flakefinder 1.1.0
|
482 |
+
pytest-rerunfailures 14.0
|
483 |
+
pytest-shard 0.1.2
|
484 |
+
pytest-xdist 3.5.0
|
485 |
+
python-dateutil 2.9.0.post0
|
486 |
+
python-hostlist 1.23.0
|
487 |
+
pytorch-lightning 2.5.0.post0
|
488 |
+
pytorch-quantization 2.1.2
|
489 |
+
pytorch-triton 3.0.0+a9bc1a364
|
490 |
+
pytz 2024.1
|
491 |
+
PyYAML 6.0.1
|
492 |
+
pyzmq 25.1.2
|
493 |
+
raft-dask 24.2.0
|
494 |
+
rapids-dask-dependency 24.2.0a0
|
495 |
+
referencing 0.34.0
|
496 |
+
regex 2023.12.25
|
497 |
+
requests 2.31.0
|
498 |
+
requests-oauthlib 2.0.0
|
499 |
+
rich 13.7.1
|
500 |
+
rmm 24.2.0
|
501 |
+
rpds-py 0.18.0
|
502 |
+
rsa 4.9
|
503 |
+
safetensors 0.5.2
|
504 |
+
scikit-learn 1.2.0
|
505 |
+
scipy 1.12.0
|
506 |
+
seaborn 0.13.2
|
507 |
+
Send2Trash 1.8.2
|
508 |
+
sentencepiece 0.2.0
|
509 |
+
sentry-sdk 2.20.0
|
510 |
+
setproctitle 1.3.4
|
511 |
+
setuptools 68.2.2
|
512 |
+
six 1.16.0
|
513 |
+
smart-open 6.4.0
|
514 |
+
smmap 5.0.2
|
515 |
+
sniffio 1.3.1
|
516 |
+
sortedcontainers 2.4.0
|
517 |
+
soundfile 0.12.1
|
518 |
+
soupsieve 2.5
|
519 |
+
soxr 0.3.7
|
520 |
+
spacy 3.7.4
|
521 |
+
spacy-legacy 3.0.12
|
522 |
+
spacy-loggers 1.0.5
|
523 |
+
sphinx_glpi_theme 0.6
|
524 |
+
srsly 2.4.8
|
525 |
+
stack-data 0.6.3
|
526 |
+
sympy 1.12
|
527 |
+
tabulate 0.9.0
|
528 |
+
tbb 2021.12.0
|
529 |
+
tblib 3.0.0
|
530 |
+
tensorboard 2.9.0
|
531 |
+
tensorboard-data-server 0.6.1
|
532 |
+
tensorboard-plugin-wit 1.8.1
|
533 |
+
tensorrt 8.6.3
|
534 |
+
terminado 0.18.1
|
535 |
+
texttable 1.7.0
|
536 |
+
thinc 8.2.3
|
537 |
+
threadpoolctl 3.3.0
|
538 |
+
thriftpy2 0.4.17
|
539 |
+
tinycss2 1.2.1
|
540 |
+
tokenizers 0.21.0
|
541 |
+
toml 0.10.2
|
542 |
+
tomli 2.0.1
|
543 |
+
toolz 0.12.1
|
544 |
+
torch 2.3.0a0+6ddf5cf85e.nv24.4
|
545 |
+
torch-tensorrt 2.3.0a0
|
546 |
+
torchdata 0.7.1a0
|
547 |
+
torchmetrics 1.6.1
|
548 |
+
torchtext 0.17.0a0
|
549 |
+
torchvision 0.18.0a0
|
550 |
+
tornado 6.4
|
551 |
+
tqdm 4.66.2
|
552 |
+
traitlets 5.9.0
|
553 |
+
transformer-engine 1.5.0+6a9edc3
|
554 |
+
transformers 4.48.0
|
555 |
+
treelite 4.0.0
|
556 |
+
typer 0.9.4
|
557 |
+
types-dataclasses 0.6.6
|
558 |
+
typing_extensions 4.10.0
|
559 |
+
ucx-py 0.36.0
|
560 |
+
uff 0.6.9
|
561 |
+
umap-learn 0.5.7
|
562 |
+
urllib3 1.26.18
|
563 |
+
wandb 0.19.4
|
564 |
+
wasabi 1.1.2
|
565 |
+
wcwidth 0.2.13
|
566 |
+
weasel 0.3.4
|
567 |
+
webencodings 0.5.1
|
568 |
+
Werkzeug 3.0.2
|
569 |
+
wheel 0.43.0
|
570 |
+
xdoctest 1.0.2
|
571 |
+
xgboost 1.7.5
|
572 |
+
yarl 1.9.4
|
573 |
+
zict 3.0.0
|
574 |
+
zipp 3.17.0
|

## Docker

The following image was used for Container 1 (all code except puncta benchmark):

```
nvcr.io/nvidia/pytorch:23.08-py3
```

The following image was used for Container 2 (puncta benchmark):

```
nvcr.io/nvidia/pytorch:24.04-py3
```
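Inside either container, a quick sanity check (a hedged sketch, not part of the repo) that the PyTorch stack and GPU are visible before running training or benchmarks:

```
# Minimal environment check; expected versions are taken from the pip lists above
import torch
print(torch.__version__)          # 2.5.0 in Container 1; 2.3.0a0+... in Container 2
print(torch.cuda.is_available())  # the embedding code falls back to CPU when this is False
```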
fuson_plm/benchmarking/README.md
ADDED
@@ -0,0 +1,11 @@
# Benchmarking

This outer directory for the FusOn-pLM benchmarks holds shared utility functions stored in `.py` files.

### embed.py

This file contains functions used to make and organize FusOn-pLM and ESM embeddings of benchmarking datasets. Its functions are used in all benchmarks.

### xgboost_predictor.py

This file contains functions used to train XGBoost predictors, which are utilized in the `puncta` benchmark.
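As a quick orientation, here is a minimal sketch of how `embed.py` is typically used from inside a benchmark folder; the input path and column name below are placeholders, and the real calls live in each benchmark's own scripts (e.g. `embedding_exploration/plot.py`):

```
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark

# Returns a dict mapping each output embedding .pkl path to the model/epoch that produced it.
# "FusOn-pLM" selects the published checkpoint; a {run_name: [epochs]} dict selects your own runs.
all_embedding_paths = embed_dataset_for_benchmark(
    fuson_ckpts="FusOn-pLM",
    input_data_path="data/example_sequences.csv",  # placeholder CSV of sequences
    input_fname="example_sequences",
    average=True,              # mean-pool per-residue embeddings into one vector per sequence
    seq_col="sequence",        # column holding the amino-acid sequences
    benchmark_fusonplm=True,
    benchmark_esm=True,        # also embed with ESM-2-650M for comparison
    overwrite=False,
)
```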
fuson_plm/benchmarking/__init__.py
ADDED
File without changes
fuson_plm/benchmarking/embed.py
ADDED
@@ -0,0 +1,296 @@
1 |
+
# Python file for making embeddings from a FusOn-pLM model for any dataset
|
2 |
+
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
|
3 |
+
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
|
4 |
+
from fuson_plm.utils.data_cleaning import find_invalid_chars
|
5 |
+
from fuson_plm.utils.constants import VALID_AAS
|
6 |
+
from fuson_plm.training.model import FusOnpLM
|
7 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
|
8 |
+
import logging
|
9 |
+
import torch
|
10 |
+
import pickle
|
11 |
+
import os
|
12 |
+
import pandas as pd
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
def validate_sequence_col(df, seq_col):
|
16 |
+
# if column isn't there, error
|
17 |
+
if seq_col not in list(df.columns):
|
18 |
+
raise Exception("Error: provided sequence column does not exist in the input dataframe")
|
19 |
+
|
20 |
+
# if column contains invalid characters, error
|
21 |
+
df['invalid_chars'] = df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS))
|
22 |
+
all_invalid_chars = set().union(*df['invalid_chars'])
|
23 |
+
df = df.drop(columns=['invalid_chars'])
|
24 |
+
if len(all_invalid_chars)>0:
|
25 |
+
raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")
|
26 |
+
|
27 |
+
# make sure there are no duplicates
|
28 |
+
sequences = df[seq_col]
|
29 |
+
if len(set(sequences))<len(sequences): log_update("\tWARNING: input data has duplicate sequences")
|
30 |
+
|
31 |
+
def load_fuson_model(ckpt_path):
|
32 |
+
# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
|
33 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
|
34 |
+
|
35 |
+
# Set device
|
36 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
print(f"Using device: {device}")
|
38 |
+
|
39 |
+
# Load model
|
40 |
+
model = AutoModel.from_pretrained(ckpt_path) # initialize model
|
41 |
+
tokenizer = AutoTokenizer.from_pretrained(ckpt_path) # initialize tokenizer
|
42 |
+
|
43 |
+
# Model to device and in eval mode
|
44 |
+
model.to(device)
|
45 |
+
model.eval() # disables dropout for deterministic results
|
46 |
+
|
47 |
+
return model, tokenizer, device
|
48 |
+
|
49 |
+
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
|
50 |
+
# Correct save path to pickle if necessary
|
51 |
+
if savepath is not None:
|
52 |
+
if savepath[-4::] != '.pkl': savepath += '.pkl'
|
53 |
+
|
54 |
+
if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")
|
55 |
+
|
56 |
+
# If no max length was passed, just set it to the maximum in the dataset
|
57 |
+
max_seq_len = max([len(s) for s in sequences])
|
58 |
+
if max_length is None: max_length=max_seq_len+2 # add 2 for BOS, EOS
|
59 |
+
|
60 |
+
# Initialize an empty dict to store the ESM embeddings
|
61 |
+
embedding_dict = {}
|
62 |
+
# Iterate through the seqs
|
63 |
+
for i in range(len(sequences)):
|
64 |
+
sequence = sequences[i]
|
65 |
+
# Get the embeddings
|
66 |
+
with torch.no_grad():
|
67 |
+
# Tokenize the input sequence
|
68 |
+
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True,max_length=max_length)
|
69 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
70 |
+
|
71 |
+
outputs = model(**inputs)
|
72 |
+
# The embeddings are in the last_hidden_state tensor
|
73 |
+
embedding = outputs.last_hidden_state
|
74 |
+
# remove extra dimension
|
75 |
+
embedding = embedding.squeeze(0)
|
76 |
+
# remove BOS and EOS tokens
|
77 |
+
embedding = embedding[1:-1, :]
|
78 |
+
|
79 |
+
# Convert embeddings to numpy array (if needed)
|
80 |
+
embedding = embedding.cpu().numpy()
|
81 |
+
|
82 |
+
# Average (if necessary)
|
83 |
+
if average:
|
84 |
+
embedding = embedding.mean(0)
|
85 |
+
|
86 |
+
# Add to dictionary
|
87 |
+
embedding_dict[sequence] = embedding
|
88 |
+
|
89 |
+
# Save individual embedding (if necessary)
|
90 |
+
if not(savepath is None) and not(save_at_end):
|
91 |
+
with open(savepath, 'ab+') as f:
|
92 |
+
d = {sequence: embedding}
|
93 |
+
pickle.dump(d, f)
|
94 |
+
|
95 |
+
# Print update (if necessary)
|
96 |
+
if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
|
97 |
+
|
98 |
+
# Dump all at once at the end (if necessary)
|
99 |
+
if not(savepath is None):
|
100 |
+
# If saving for the first time, just dump it
|
101 |
+
if save_at_end:
|
102 |
+
with open(savepath, 'wb') as f:
|
103 |
+
pickle.dump(embedding_dict, f)
|
104 |
+
# If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
|
105 |
+
else:
|
106 |
+
redump_pickle_dictionary(savepath)
|
107 |
+
|
108 |
+
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path = None, average=True, overwrite=True, print_updates=False,max_length=2000):
|
109 |
+
# Make sure we aren't overwriting pre-existing embeddings
|
110 |
+
if os.path.exists(path_to_output):
|
111 |
+
if overwrite:
|
112 |
+
log_update(f"WARNING: these embeddings may already exist at {path_to_output} and will be overwritten")
|
113 |
+
else:
|
114 |
+
log_update(f"WARNING: these embeddings may already exist at {path_to_output}. Skipping.")
|
115 |
+
return None
|
116 |
+
|
117 |
+
dataset = pd.read_csv(path_to_file)
|
118 |
+
# Make sure the sequence column is valid
|
119 |
+
validate_sequence_col(dataset, seq_col)
|
120 |
+
|
121 |
+
sequences = dataset[seq_col].unique().tolist() # ensure all entries are unique
|
122 |
+
|
123 |
+
### If FusOn-pLM: make fusion embeddings
|
124 |
+
if model_type=='fuson_plm':
|
125 |
+
if not(os.path.exists(fuson_ckpt_path)): raise Exception("FusOn-pLM ckpt path does not exist")
|
126 |
+
|
127 |
+
# Load model
|
128 |
+
try:
|
129 |
+
model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
|
130 |
+
except:
|
131 |
+
raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}")
|
132 |
+
|
133 |
+
# Generate embeddings
|
134 |
+
try:
|
135 |
+
get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
|
136 |
+
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
|
137 |
+
max_length=max_length)
|
138 |
+
except:
|
139 |
+
raise Exception("Could not generate FusOn-pLM embeddings")
|
140 |
+
|
141 |
+
if model_type=='esm2_t33_650M_UR50D':
|
142 |
+
# Load model
|
143 |
+
try:
|
144 |
+
model, tokenizer, device = load_esm2_type(model_type)
|
145 |
+
except:
|
146 |
+
raise Exception(f"Could not load {model_type}")
|
147 |
+
# Generate embeddings
|
148 |
+
try:
|
149 |
+
get_esm_embeddings(model, tokenizer, sequences, device, average=average,
|
150 |
+
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
|
151 |
+
max_length=max_length)
|
152 |
+
except:
|
153 |
+
raise Exception(f"Could not generate {model_type} embeddings")
|
154 |
+
|
155 |
+
if model_type=="prot_t5_xl_half_uniref50_enc":
|
156 |
+
# Load model
|
157 |
+
try:
|
158 |
+
model, tokenizer, device = load_prott5()
|
159 |
+
except:
|
160 |
+
raise Exception(f"Could not load {model_type}")
|
161 |
+
# Generate embeddings
|
162 |
+
try:
|
163 |
+
get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
|
164 |
+
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
|
165 |
+
max_length=max_length)
|
166 |
+
except:
|
167 |
+
raise Exception(f"Could not generate {model_type} embeddings")
|
168 |
+
|
169 |
+
|
170 |
+
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False,max_length=None):
|
171 |
+
# make a directory for embeddings inside the benchmark folder if one doesn't already exist
|
172 |
+
os.makedirs('embeddings',exist_ok=True)
|
173 |
+
|
174 |
+
# Extract input file name from configs
|
175 |
+
emb_type_tag ='average' if average else '2D'
|
176 |
+
|
177 |
+
all_embedding_paths = dict() # dictionary organized where embedding path points to model, epoch
|
178 |
+
|
179 |
+
# make the embedding files. Put them in an embedding directory
|
180 |
+
if benchmark_fusonplm:
|
181 |
+
os.makedirs('embeddings/fuson_plm',exist_ok=True)
|
182 |
+
|
183 |
+
log_update(f"\nMaking Fuson-PLM embeddings")
|
184 |
+
# make subdirectories for all the checkpoints being benchmarked
|
185 |
+
if type(fuson_ckpts)==dict:
|
186 |
+
for model_name, epoch_list in fuson_ckpts.items():
|
187 |
+
os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
|
188 |
+
for epoch in epoch_list:
|
189 |
+
# Assemble ckpt path and throw error if it doesn't exist
|
190 |
+
fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
|
191 |
+
if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
|
192 |
+
|
193 |
+
# Make output directory and output embedding path
|
194 |
+
embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
|
195 |
+
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
|
196 |
+
os.makedirs(embedding_output_dir,exist_ok=True)
|
197 |
+
|
198 |
+
# Make dictionary item
|
199 |
+
model_type = 'fuson_plm'
|
200 |
+
all_embedding_paths[embedding_output_path] = {
|
201 |
+
'model_type': model_type,
|
202 |
+
'model': model_name,
|
203 |
+
'epoch': epoch
|
204 |
+
}
|
205 |
+
|
206 |
+
# Create embeddings (or skip if they're already made)
|
207 |
+
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
|
208 |
+
embed_dataset(input_data_path, embedding_output_path,
|
209 |
+
seq_col=seq_col, model_type=model_type,
|
210 |
+
fuson_ckpt_path=fuson_ckpt_path, average=average,
|
211 |
+
overwrite=overwrite,print_updates=True,
|
212 |
+
max_length=max_length)
|
213 |
+
elif fuson_ckpts=="FusOn-pLM":
|
214 |
+
model_name = "best"
|
215 |
+
os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
|
216 |
+
|
217 |
+
# Assemble ckpt path and throw error if it doesn't exist
|
218 |
+
fuson_ckpt_path = "../../.." # go back to the FusOn-pLM directory to find the best ckpt
|
219 |
+
if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
|
220 |
+
|
221 |
+
# Make output directory and output embedding path
|
222 |
+
embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
|
223 |
+
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
|
224 |
+
os.makedirs(embedding_output_dir,exist_ok=True)
|
225 |
+
|
226 |
+
# Make dictionary item
|
227 |
+
model_type = 'fuson_plm'
|
228 |
+
all_embedding_paths[embedding_output_path] = {
|
229 |
+
'model_type': model_type,
|
230 |
+
'model': model_name,
|
231 |
+
'epoch': None
|
232 |
+
}
|
233 |
+
|
234 |
+
# Create embeddings (or skip if they're already made)
|
235 |
+
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
|
236 |
+
embed_dataset(input_data_path, embedding_output_path,
|
237 |
+
seq_col=seq_col, model_type=model_type,
|
238 |
+
fuson_ckpt_path=fuson_ckpt_path, average=average,
|
239 |
+
overwrite=overwrite,print_updates=True,
|
240 |
+
max_length=max_length)
|
241 |
+
else:
|
242 |
+
raise Exception(f"Error. fuson_ckpts should be a dict or str")
|
243 |
+
|
244 |
+
# make the embedding files. Put them in an embedding directory
|
245 |
+
if benchmark_esm:
|
246 |
+
os.makedirs('embeddings/esm2_t33_650M_UR50D',exist_ok=True)
|
247 |
+
|
248 |
+
# make output path
|
249 |
+
embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'
|
250 |
+
|
251 |
+
# Make dictionary item
|
252 |
+
model_type = 'esm2_t33_650M_UR50D'
|
253 |
+
all_embedding_paths[embedding_output_path] = {
|
254 |
+
'model_type': model_type,
|
255 |
+
'model': model_type,
|
256 |
+
'epoch': np.nan
|
257 |
+
}
|
258 |
+
|
259 |
+
log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
|
260 |
+
embed_dataset(input_data_path, embedding_output_path,
|
261 |
+
seq_col=seq_col, model_type=model_type,
|
262 |
+
fuson_ckpt_path = None, average=average,
|
263 |
+
overwrite=overwrite,print_updates=True,
|
264 |
+
max_length=max_length)
|
265 |
+
|
266 |
+
if benchmark_prott5:
|
267 |
+
os.makedirs('embeddings/prot_t5_xl_half_uniref50_enc',exist_ok=True)
|
268 |
+
|
269 |
+
# make output path
|
270 |
+
embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'
|
271 |
+
|
272 |
+
# Make dictionary item
|
273 |
+
model_type = 'prot_t5_xl_half_uniref50_enc'
|
274 |
+
all_embedding_paths[embedding_output_path] = {
|
275 |
+
'model_type': model_type,
|
276 |
+
'model': model_type,
|
277 |
+
'epoch': np.nan
|
278 |
+
}
|
279 |
+
|
280 |
+
log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
|
281 |
+
embed_dataset(input_data_path, embedding_output_path,
|
282 |
+
seq_col=seq_col, model_type=model_type,
|
283 |
+
fuson_ckpt_path = None, average=average,
|
284 |
+
overwrite=overwrite,print_updates=True,
|
285 |
+
max_length=max_length)
|
286 |
+
|
287 |
+
if benchmark_fo_puncta_ml:
|
288 |
+
embedding_output_path =f'FOdb_physicochemical_embeddings.pkl'
|
289 |
+
# Make dictionary item
|
290 |
+
all_embedding_paths[embedding_output_path] = {
|
291 |
+
'model_type': 'fo_puncta_ml',
|
292 |
+
'model': 'fo_puncta_ml',
|
293 |
+
'epoch': np.nan
|
294 |
+
}
|
295 |
+
|
296 |
+
return all_embedding_paths
|
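For downstream use, each `.pkl` written by these functions maps an amino-acid sequence to its embedding (a single mean-pooled vector when `average=True`, otherwise a per-residue matrix). A minimal loading sketch, with a placeholder path:

```
import pickle
import pandas as pd

# Placeholder path; real files follow the embeddings/<model>/... layout created above
with open("embeddings/fuson_plm/best/example_sequences_average_embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)   # dict: sequence (str) -> numpy array

df = pd.DataFrame(list(embeddings.items()), columns=["sequence", "embedding"])
print(len(df), df["embedding"].iloc[0].shape)
```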
fuson_plm/benchmarking/embedding_exploration/README.md
ADDED
@@ -0,0 +1,58 @@
# Embedding exploration

This folder contains all the data and code needed to run the embedding exploration (Fig. S3).

### Data download
To help select TF (transcription factor) and kinase-containing fusions for investigation (Fig. S3a), Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8) was downloaded as a reference list of transcription factors and kinases.

```
benchmarking/
└── embedding_exploration/
    └── data/
        ├── salokas_2020_tableS3.csv
        ├── tf_and_kinase_fusions.csv
        ├── top_genes.csv
```

- **`data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8)
- **`data/tf_and_kinase_fusions.csv`**: set of TF::TF and Kinase::Kinase fusion oncoproteins from the FusOn-DB database, curated in `plot.py`
- **`data/top_genes.csv`**: fusion oncoproteins (and their head and tail components) visualized in Fig. S3b. Sequences for the head and tail components were pulled from the best-aligned sequences in `fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl`

### Plotting

Run `plot.py` to regenerate the plots in Figure S3. Its behavior is controlled by `config.py`:

```
# Dictionary: key = run name, values = epochs (use this option if you've trained your own model),
# or "FusOn-pLM" to use the official model
FUSON_PLM_CKPT = "FusOn-pLM"

# Type of dim reduction
PLOT_UMAP = True
PLOT_TSNE = False

# Overwriting configs
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
```

To run, use:
```
nohup python plot.py > plot.out 2> plot.err &
```
- All **results** are stored in `embedding_exploration/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started the run.

Below are the FusOn-pLM paper results in `results/final/umap_plots/fuson_plm/best/`:

```
benchmarking/
└── embedding_exploration/
    └── results/final/umap_plots/fuson_plm/best/
        ├── favorites/
        │   ├── umap_favorites_source_data.csv
        │   └── umap_favorites_visualization.png
        └── tf_and_kinase/
            ├── umap_tf_and_kinase_fusions_source_data.csv
            └── umap_tf_and_kinase_fusions_visualization.png
```

- **`favorites/umap_favorites_visualization.png`**: Fig. S3b, with the data directly plotted stored in `favorites/umap_favorites_source_data.csv`
- **`tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png`**: Fig. S3a, with the data directly plotted stored in `tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv`.
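The dimensionality reduction behind both figures is a plain 2-D UMAP fit on the mean-pooled embeddings; below is a standalone sketch mirroring the UMAP parameters used in `plot.py` (the random input matrix is a placeholder for the real embedding array):

```
import numpy as np
import umap

# Placeholder for a (n_sequences, hidden_dim) matrix of mean-pooled embeddings
embeddings = np.random.rand(200, 1280)

# Same settings as get_umap_embeddings() in plot.py
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, metric="euclidean")
coords = reducer.fit_transform(embeddings)  # (200, 2): the umap1/umap2 columns in the source-data CSVs
print(coords.shape)
```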
fuson_plm/benchmarking/embedding_exploration/__init__.py
ADDED
File without changes
fuson_plm/benchmarking/embedding_exploration/config.py
ADDED
@@ -0,0 +1,10 @@
# Dictionary: key = run name, values = epochs (use this option if you've trained your own model)
# Or, the string "FusOn-pLM" (use this option for the "best" ckpt from the FusOn-pLM paper)
FUSON_PLM_CKPT = "FusOn-pLM"

# Type of dim reduction
PLOT_UMAP = True
PLOT_TSNE = False

# Overwriting configs
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
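These module-level constants are imported by `plot.py` and forwarded to the embedding and plotting helpers; the sketch below shows the pattern (the `dimred_types` construction is an illustrative assumption, not a line copied from `plot.py`):

```
import fuson_plm.benchmarking.embedding_exploration.config as config

# Roughly how plot.py consumes these values (illustrative, simplified)
dimred_types = [name for name, on in [("umap", config.PLOT_UMAP), ("tsne", config.PLOT_TSNE)] if on]
print(dimred_types)                    # ["umap"] with the defaults above
print(config.FUSON_PLM_CKPT)           # "FusOn-pLM", or a {run_name: [epochs]} dict for your own checkpoints
print(config.PERMISSION_TO_OVERWRITE)  # forwarded as overwrite= to embed_dataset_for_benchmark
```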
fuson_plm/benchmarking/embedding_exploration/data/salokas_2020_tableS3.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
size 19626
fuson_plm/benchmarking/embedding_exploration/data/tf_and_kinase_fusions.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:372321137ed12b2f8aa7c4891dafd0e88d64d5c5d0ea9c6f3a0aa9d897e8ead6
size 557262
fuson_plm/benchmarking/embedding_exploration/data/top_genes.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33d568fe413107318caebd5ee260ee66fe8571461ed8f8d1b47888441f7b5034
size 16695
fuson_plm/benchmarking/embedding_exploration/plot.py
ADDED
@@ -0,0 +1,496 @@
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import pickle
|
4 |
+
from sklearn.manifold import TSNE
|
5 |
+
import matplotlib.font_manager as fm
|
6 |
+
from matplotlib.font_manager import FontProperties
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import matplotlib.gridspec as gridspec
|
9 |
+
import matplotlib.patches as patches
|
10 |
+
import seaborn as sns
|
11 |
+
import umap
|
12 |
+
import os
|
13 |
+
|
14 |
+
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
|
15 |
+
import fuson_plm.benchmarking.embedding_exploration.config as config
|
16 |
+
from fuson_plm.utils.visualizing import set_font
|
17 |
+
from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
|
18 |
+
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy
|
19 |
+
|
20 |
+
|
21 |
+
def get_dimred_embeddings(embeddings, dimred_type="umap"):
|
22 |
+
if dimred_type=="umap":
|
23 |
+
dimred_embeddings = get_umap_embeddings(embeddings)
|
24 |
+
return dimred_embeddings
|
25 |
+
if dimred_type=="tsne":
|
26 |
+
dimred_embeddings = get_tsne_embeddings(embeddings)
|
27 |
+
return dimred_embeddings
|
28 |
+
|
29 |
+
def get_tsne_embeddings(embeddings):
|
30 |
+
embeddings = np.array(embeddings)
|
31 |
+
tsne = TSNE(n_components=2, random_state=42,perplexity=5)
|
32 |
+
tsne_embeddings = tsne.fit_transform(embeddings)
|
33 |
+
return tsne_embeddings
|
34 |
+
|
35 |
+
def get_umap_embeddings(embeddings):
|
36 |
+
embeddings = np.array(embeddings)
|
37 |
+
umap_model = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, metric='euclidean') # default parameters for UMAP
|
38 |
+
umap_embeddings = umap_model.fit_transform(embeddings)
|
39 |
+
return umap_embeddings
|
40 |
+
|
41 |
+
def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
|
42 |
+
"""
|
43 |
+
Plots a circle filled in halves with specified colors.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
- ax: Matplotlib axis to draw on.
|
47 |
+
- x, y: Coordinates of the marker.
|
48 |
+
- left_color: Color of the left half.
|
49 |
+
- right_color: Color of the right half.
|
50 |
+
- size: Size of the marker.
|
51 |
+
"""
|
52 |
+
radius = (size ** 0.5) / 100 # Scale the radius
|
53 |
+
# Create the left half-circle (90° to 270°)
|
54 |
+
left_half = patches.Wedge((x, y), radius, 90, 270, color=left_color, ec="black")
|
55 |
+
# Create the right half-circle (270° to 90°)
|
56 |
+
right_half = patches.Wedge((x, y), radius, 270, 90, color=right_color, ec="black")
|
57 |
+
|
58 |
+
# Add both halves to the plot
|
59 |
+
ax.add_patch(left_half)
|
60 |
+
ax.add_patch(right_half)
|
61 |
+
|
62 |
+
def plot_umap_scatter_tftf_kk(df, filename="umap.png"):
|
63 |
+
"""
|
64 |
+
Plots a 2D scatterplot of UMAP coordinates with different markers and colors based on 'type'.
|
65 |
+
Only for TF::TF and Kinase::Kinase fusions
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
- df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'type' columns.
|
69 |
+
"""
|
70 |
+
set_font()
|
71 |
+
|
72 |
+
# Define colors for each type
|
73 |
+
colors = {
|
74 |
+
"TF": "pink",
|
75 |
+
"Kinase": "orange"
|
76 |
+
}
|
77 |
+
|
78 |
+
# Define marker types and colors for each combination
|
79 |
+
marker_colors = {
|
80 |
+
"TF::TF": colors["TF"],
|
81 |
+
"Kinase::Kinase": colors["Kinase"],
|
82 |
+
}
|
83 |
+
|
84 |
+
# Create the plot
|
85 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
86 |
+
x_min, x_max = df["umap1"].min() - 1, df["umap1"].max() + 1
|
87 |
+
y_min, y_max = df["umap2"].min() - 1, df["umap2"].max() + 1
|
88 |
+
ax.set_xlim(x_min, x_max)
|
89 |
+
ax.set_ylim(y_min, y_max)
|
90 |
+
|
91 |
+
# Plot each point with the specified half-filled marker
|
92 |
+
for i in range(len(df)):
|
93 |
+
row = df.iloc[i]
|
94 |
+
marker_type = row["fusion_type"]
|
95 |
+
x, y = row["umap1"], row["umap2"]
|
96 |
+
color = marker_colors[marker_type]
|
97 |
+
|
98 |
+
ax.scatter(x, y, color=color, s=15, edgecolors="black", linewidth=0.5)
|
99 |
+
|
100 |
+
# Add custom legend
|
101 |
+
legend_elements = [
|
102 |
+
patches.Patch(facecolor="pink", edgecolor="black", label="TF::TF"),
|
103 |
+
patches.Patch(facecolor="orange", edgecolor="black", label="Kinase::Kinase")
|
104 |
+
]
|
105 |
+
ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)
|
106 |
+
|
107 |
+
# Add labels and title
|
108 |
+
plt.xlabel("UMAP 1", fontsize=20)
|
109 |
+
plt.ylabel("UMAP 2", fontsize=20)
|
110 |
+
plt.title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)
|
111 |
+
plt.tight_layout()
|
112 |
+
|
113 |
+
# Save and show the plot
|
114 |
+
plt.savefig(filename, dpi=300)
|
115 |
+
plt.show()
|
116 |
+
|
117 |
+
def plot_umap_scatter_half_filled(df, filename="umap.png"):
|
118 |
+
"""
|
119 |
+
Plots a 2D scatterplot of UMAP coordinates with different markers and colors based on 'type'.
|
120 |
+
|
121 |
+
Parameters:
|
122 |
+
- df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'type' columns.
|
123 |
+
"""
|
124 |
+
# Define colors for each type
|
125 |
+
colors = {
|
126 |
+
"TF": "pink",
|
127 |
+
"Kinase": "orange",
|
128 |
+
"Other": "grey"
|
129 |
+
}
|
130 |
+
|
131 |
+
# Define marker types and colors for each combination
|
132 |
+
marker_colors = {
|
133 |
+
"TF::TF": {"left": colors["TF"], "right": colors["TF"]},
|
134 |
+
"TF::Other": {"left": colors["TF"], "right": colors["Other"]},
|
135 |
+
"Other::TF": {"left": colors["Other"], "right": colors["TF"]},
|
136 |
+
"Kinase::Kinase": {"left": colors["Kinase"], "right": colors["Kinase"]},
|
137 |
+
"Kinase::Other": {"left": colors["Kinase"], "right": colors["Other"]},
|
138 |
+
"Other::Kinase": {"left": colors["Other"], "right": colors["Kinase"]},
|
139 |
+
"Kinase::TF": {"left": colors["Kinase"], "right": colors["TF"]},
|
140 |
+
"TF::Kinase": {"left": colors["TF"], "right": colors["Kinase"]},
|
141 |
+
"Other::Other": {"left": colors["Other"], "right": colors["Other"]}
|
142 |
+
}
|
143 |
+
|
144 |
+
# Create the plot
|
145 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
146 |
+
x_min, x_max = df["umap1"].min() - 1, df["umap1"].max() + 1
|
147 |
+
y_min, y_max = df["umap2"].min() - 1, df["umap2"].max() + 1
|
148 |
+
ax.set_xlim(x_min, x_max)
|
149 |
+
ax.set_ylim(y_min, y_max)
|
150 |
+
|
151 |
+
# Plot each point with the specified half-filled marker
|
152 |
+
for i in range(len(df)):
|
153 |
+
row = df.iloc[i]
|
154 |
+
marker_type = row["fusion_type"]
|
155 |
+
x, y = row["umap1"], row["umap2"]
|
156 |
+
left_color = marker_colors[marker_type]["left"]
|
157 |
+
right_color = marker_colors[marker_type]["right"]
|
158 |
+
plot_half_filled_circle(ax, x, y, left_color, right_color, size=100)
|
159 |
+
|
160 |
+
# Add custom legend
|
161 |
+
legend_elements = [
|
162 |
+
patches.Patch(facecolor="pink", edgecolor="black", label="TF"),
|
163 |
+
patches.Patch(facecolor="orange", edgecolor="black", label="Kinase"),
|
164 |
+
patches.Patch(facecolor="grey", edgecolor="black", label="Other")
|
165 |
+
]
|
166 |
+
ax.legend(handles=legend_elements, title="Type")
|
167 |
+
|
168 |
+
# Add labels and title
|
169 |
+
plt.xlabel("UMAP 1")
|
170 |
+
plt.ylabel("UMAP 2")
|
171 |
+
plt.title("UMAP Scatter Plot")
|
172 |
+
plt.tight_layout()
|
173 |
+
|
174 |
+
# Save and show the plot
|
175 |
+
plt.savefig(filename, dpi=300)
|
176 |
+
plt.show()
|
177 |
+
|
178 |
+
def get_gene_type(gene, d):
|
179 |
+
if gene in d:
|
180 |
+
if d[gene] == 'kinase':
|
181 |
+
return 'Kinase'
|
182 |
+
if d[gene] == 'tf':
|
183 |
+
return 'TF'
|
184 |
+
else:
|
185 |
+
return 'Other'
|
186 |
+
|
187 |
+
def get_tf_and_kinase_fusions_dataset():
|
188 |
+
# Load TF and Kinase Fusions
|
189 |
+
tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
|
190 |
+
print(tf_kinase_parts)
|
191 |
+
ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'],tf_kinase_parts['Kinase or TF']))
|
192 |
+
|
193 |
+
# This one has each row with one fusiongene name
|
194 |
+
fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
|
195 |
+
fuson_ht_db[['hg','tg']] = fuson_ht_db['fusiongenes'].str.split("::",expand=True)
|
196 |
+
|
197 |
+
fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
|
198 |
+
fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
|
199 |
+
fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type']+'::'+fuson_ht_db['tg_type']
|
200 |
+
fuson_ht_db['type']=['fusion']*len(fuson_ht_db)
|
201 |
+
# Keep 100 things in each category
|
202 |
+
categories = pd.DataFrame(fuson_ht_db['fusion_type'].value_counts()).reset_index()['index'].tolist()
|
203 |
+
categories = ["TF::TF","Kinase::Kinase"] # manually set some easier categories
|
204 |
+
print(categories)
|
205 |
+
plot_df = None
|
206 |
+
|
207 |
+
for i, cat in enumerate(categories):
|
208 |
+
random_sample = fuson_ht_db.loc[fuson_ht_db['fusion_type']==cat].reset_index(drop=True)
|
209 |
+
#random_sample = random_sample.sample(n=100, random_state=1).reset_index(drop=True)
|
210 |
+
if i==0:
|
211 |
+
plot_df = random_sample
|
212 |
+
else:
|
213 |
+
plot_df = pd.concat([plot_df,random_sample],axis=0).reset_index(drop=True)
|
214 |
+
|
215 |
+
print(plot_df['fusion_type'].value_counts())
|
216 |
+
|
217 |
+
# Now, need to add in the embeddings
|
218 |
+
plot_df = plot_df[['aa_seq','fusiongenes','fusion_type','type']].rename(
|
219 |
+
columns={'aa_seq':'sequence','fusiongenes':'ID'}
|
220 |
+
)
|
221 |
+
|
222 |
+
return plot_df
|
223 |
+
|
224 |
+
def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = '', dimred_type='umap'):
|
225 |
+
fuson_db = pd.read_csv("../../data/fuson_db.csv")
|
226 |
+
seq_id_dict = dict(zip(fuson_db['aa_seq'],fuson_db['seq_id']))
|
227 |
+
|
228 |
+
# add sequences so we can save results/sequence
|
229 |
+
data = seqs_with_embeddings[[f'{dimred_type}1',f'{dimred_type}2','sequence','fusion_type','ID']]
|
230 |
+
data['seq_id'] = data['sequence'].map(seq_id_dict)
|
231 |
+
|
232 |
+
tfkinase_save_dir = f"{savedir}"
|
233 |
+
os.makedirs(tfkinase_save_dir,exist_ok=True)
|
234 |
+
data.to_csv(f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_source_data.csv",index=False)
|
235 |
+
plot_umap_scatter_tftf_kk(data,filename=f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_visualization.png")
|
236 |
+
|
237 |
+
def tf_and_kinase_fusions_plot(dimred_types, output_dir):
|
238 |
+
"""
|
239 |
+
Makes the embeddings, THEN calls the plot. Only for the TF::TF and Kinase::Kinase fusions
|
240 |
+
"""
|
241 |
+
plot_df = get_tf_and_kinase_fusions_dataset()
|
242 |
+
plot_df.to_csv("data/tf_and_kinase_fusions.csv",index=False)
|
243 |
+
|
244 |
+
# path to the pkl file with FOdb embeddings
|
245 |
+
input_fname='tf_and_kinase'
|
246 |
+
all_embedding_paths = embed_dataset_for_benchmark(
|
247 |
+
fuson_ckpts=config.FUSON_PLM_CKPT,
|
248 |
+
input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
|
249 |
+
average=True, seq_col='sequence',
|
250 |
+
benchmark_fusonplm=True,
|
251 |
+
benchmark_esm=False,
|
252 |
+
benchmark_fo_puncta_ml=False,
|
253 |
+
overwrite=config.PERMISSION_TO_OVERWRITE)
|
254 |
+
|
255 |
+
# For each of the models we are benchmarking, load embeddings and make plots
|
256 |
+
log_update("\nEmbedding sequences")
|
257 |
+
# loop through the embedding paths and train each one
|
258 |
+
for embedding_path, details in all_embedding_paths.items():
|
259 |
+
log_update(f"\tBenchmarking embeddings at: {embedding_path}")
|
260 |
+
try:
|
261 |
+
with open(embedding_path, "rb") as f:
|
262 |
+
embeddings = pickle.load(f)
|
263 |
+
except:
|
264 |
+
raise Exception(f"Cannot read embeddings from {embedding_path}")
|
265 |
+
|
266 |
+
# combine the embeddings and splits into one dataframe
|
267 |
+
seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
|
268 |
+
seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'}) # the column that was called FusOn-pLM is now called embedding
|
269 |
+
seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')
|
270 |
+
# get UMAP transform of the embeddings
|
271 |
+
for dimred_type in dimred_types:
|
272 |
+
dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(),dimred_type=dimred_type)
|
273 |
+
|
274 |
+
# turn the result into a dataframe, and add it to seqs_with_embeddings
|
275 |
+
data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
|
276 |
+
# save the umap data!
|
277 |
+
model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])
|
278 |
+
|
279 |
+
seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data
|
280 |
+
|
281 |
+
# make subdirectory
|
282 |
+
intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
|
283 |
+
cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
|
284 |
+
|
285 |
+
os.makedirs(cur_output_dir,exist_ok=True)
|
286 |
+
make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = cur_output_dir, dimred_type=dimred_type)
|
287 |
+
|
288 |
+
def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=None, dimred_type='umap'):
    """
    Make plots showing that PAX3::FOXO1, EWS::FLI1, SS18::SSX1, and EML4::ALK are embedded distinctly from their heads and tails
    """
    set_font()

    # Load one sequence each for the test-set fusions PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK, plus their head and tail proteins
    data = pd.read_csv("data/top_genes.csv")
    seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")
    seqs_with_embeddings["Type"] = [""] * len(seqs_with_embeddings)
    seqs_with_embeddings.loc[
        seqs_with_embeddings["gene"].str.contains("::"), "Type"
    ] = "fusion_embeddings"
    heads = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::", expand=True)[0].tolist()
    tails = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::", expand=True)[1].tolist()
    seqs_with_embeddings.loc[
        seqs_with_embeddings["gene"].isin(heads), "Type"
    ] = "h_embeddings"
    seqs_with_embeddings.loc[
        seqs_with_embeddings["gene"].isin(tails), "Type"
    ] = "t_embeddings"

    # build a merge table pairing each fusion with the sequences of its head and tail proteins
    merge = seqs_with_embeddings.loc[seqs_with_embeddings['gene'].str.contains('::')].reset_index(drop=True)[['gene', 'sequence']]
    merge["head"] = merge["gene"].str.split("::", expand=True)[0]
    merge["tail"] = merge["gene"].str.split("::", expand=True)[1]
    merge = pd.merge(merge, seqs_with_embeddings[['gene', 'sequence']].rename(
        columns={'gene': 'head', 'sequence': 'h_sequence'}),
        on='head', how='left'
    )
    merge = pd.merge(merge, seqs_with_embeddings[['gene', 'sequence']].rename(
        columns={'gene': 'tail', 'sequence': 't_sequence'}),
        on='tail', how='left'
    )

    plt.figure()

    # Define colors and markers
    colors = {
        'fusion_embeddings': '#cf9dfa',  # old color #0C4A4D
        'h_embeddings': '#eb8888',       # old color #619283
        't_embeddings': '#5fa3e3',       # old color #619283
    }
    markers = {
        'fusion_embeddings': 'o',
        'h_embeddings': '^',
        't_embeddings': 'v'
    }
    label_map = {
        'fusion_embeddings': 'Fusion',
        'h_embeddings': 'Head',
        't_embeddings': 'Tail',
    }

    # Create a 2x3 grid of plots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    #fig, axes = plt.subplots(1, 4, figsize=(18, 7))

    # Get the global min and max for the x and y axis ranges
    all_dimred1 = seqs_with_embeddings[f'{dimred_type}1']
    all_dimred2 = seqs_with_embeddings[f'{dimred_type}2']
    x_min, x_max = all_dimred1.min(), all_dimred1.max()
    y_min, y_max = all_dimred2.min(), all_dimred2.max()
    x_min, x_max = [11, 16]  # manually set range for cleaner plotting
    y_min, y_max = [10, 22]

    # Determine tick positions
    x_ticks = np.arange(x_min, x_max + 1, 1)
    y_ticks = np.arange(y_min, y_max + 1, 1)

    # Flatten the axes array for easier iteration
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        # Extract the gene names from the current row
        fgene_name = merge.loc[i, 'gene']
        hgene = merge.loc[i, 'head']
        tgene = merge.loc[i, 'tail']

        # Filter the reduced embeddings for the relevant entries
        dimred_data = seqs_with_embeddings[seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])]

        # Plot each type
        for emb_type in dimred_data['Type'].unique():
            subset = dimred_data[dimred_data['Type'] == emb_type]
            ax.scatter(subset[f'{dimred_type}1'], subset[f'{dimred_type}2'], label=label_map[emb_type], color=colors[emb_type], marker=markers[emb_type], s=120, zorder=3)

        ax.set_title(f'{fgene_name}', fontsize=44)
        label_transform = {
            'tsne': 't-SNE',
            'umap': 'UMAP'
        }
        ax.set_xlabel(f'{label_transform[dimred_type]} 1', fontsize=44)
        ax.set_ylabel(f'{label_transform[dimred_type]} 2', fontsize=44)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)

        # Set the same limits and ticks for all axes
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)

        # Rotate x-axis labels
        ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')

        ax.tick_params(axis='x', labelsize=16)
        ax.tick_params(axis='y', labelsize=16)

        for label in ax.get_xticklabels():
            label.set_fontsize(24)
        for label in ax.get_yticklabels():
            label.set_fontsize(24)

        # Show the legend on the first panel only
        if i == 0:
            legend = ax.legend(fontsize=20, markerscale=2, loc='best')
            for text in legend.get_texts():
                text.set_fontsize(24)

    # Adjust layout to prevent overlap
    plt.tight_layout()

    # Show the plot
    plt.show()

    # Save the figure
    plt.savefig(f'{savedir}/{dimred_type}_favorites_visualization.png', dpi=300)

    # Save the source data (use the dimred_type columns so t-SNE runs do not look for UMAP columns)
    seq_to_id_dict = pd.read_csv("../../data/fuson_db.csv")
    seq_to_id_dict = dict(zip(seq_to_id_dict['aa_seq'], seq_to_id_dict['seq_id']))
    seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
    seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2', 'sequence', 'Type', 'gene', 'id', 'seq_id']].to_csv(f"{savedir}/{dimred_type}_favorites_source_data.csv", index=False)

def fusion_v_parts_favorites(dimred_types, output_dir):
    """
    Makes the embeddings, THEN calls the plot, only on the four favorite fusions and their head/tail proteins.
    """
    # embed the favorites dataset; returns a dict mapping each .pkl embedding path to its details
    input_fname = 'favorites'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/top_genes.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    # For each of the models we are benchmarking, load embeddings and make plots
    log_update("\nEmbedding sequences")
    # loop through the embedding paths and plot each one
    for embedding_path, details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            with open(embedding_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception:
            raise Exception(f"Cannot read embeddings from {embedding_path}")

        # combine the embeddings into one dataframe
        seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
        seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'})  # the column that was called FusOn-pLM is now called 'embedding'

        # get the dimensionality-reduced (UMAP or t-SNE) transform of the embeddings
        for dimred_type in dimred_types:
            dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type)

            # turn the result into a dataframe, and add it to seqs_with_embeddings
            data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
            # model name, parsed from the embedding path
            model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])

            seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data

            # make a subdirectory for this model and save the plot there
            intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"

            os.makedirs(cur_output_dir, exist_ok=True)
            make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type)

def main():
    # make directory to save results
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    dimred_types = []
    if config.PLOT_UMAP:
        dimred_types.append("umap")
        #os.makedirs(f"{output_dir}/umap_data",exist_ok=True)
        os.makedirs(f"{output_dir}/umap_plots", exist_ok=True)
    if config.PLOT_TSNE:
        dimred_types.append("tsne")
        #os.makedirs(f"{output_dir}/tsne_data",exist_ok=True)
        os.makedirs(f"{output_dir}/tsne_plots", exist_ok=True)

    with open_logfile(f'{output_dir}/embedding_exploration_log.txt'):
        print_configpy(config)
        # make the distinct-embeddings plot for the four favorite fusions
        fusion_v_parts_favorites(dimred_types, output_dir)

        tf_and_kinase_fusions_plot(dimred_types, output_dir)


if __name__ == "__main__":
    main()
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28c0b51f513da01df3dee3c4e71aa0c583bd57d9878137bdac9e7ebc704694e4
size 17383
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_visualization.png
ADDED
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b26a5a6c2f8f54225fd46f01dab52813532438732624561af8e2e4ad005e5dc7
size 570073
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png
ADDED
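
These result files are stored as Git LFS pointers, so they must be materialized (for example with `git lfs pull`) before use. Once downloaded, the source-data CSV can be re-plotted without recomputing any embeddings. A rough sketch is below; the column names come from `plot.py` above, the path from this listing, and the single-panel styling is a simplification of the paper figure.

```
import pandas as pd
import matplotlib.pyplot as plt

# Reload the saved UMAP coordinates for the favorites plot (run `git lfs pull` first)
df = pd.read_csv("results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv")

# Color each fusion's points by gene on a single panel (simplified relative to the per-fusion grid)
for gene, sub in df[df["gene"].str.contains("::")].groupby("gene"):
    plt.scatter(sub["umap1"], sub["umap2"], label=gene)
plt.xlabel("UMAP 1")
plt.ylabel("UMAP 2")
plt.legend()
plt.show()
```
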
fuson_plm/benchmarking/mutation_prediction/README.md
CHANGED
@@ -81,7 +81,7 @@ To run, use:
 ```
 nohup python discover.py > discover.out 2> discover.err &
 ```
-- All **results** are stored in `
+- All **results** are stored in `mutation_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
 
 Below are the FusOn-pLM paper results in `results/final`:
 
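
The timestamped directory comes from a helper along the lines of `get_local_time()` used in the benchmarking scripts; the exact format string is an implementation detail. A sketch of the idea, with an illustrative format chosen here:

```
from datetime import datetime
import os

# Illustrative only: build a results subdirectory keyed by the run's start time.
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
output_dir = os.path.join("results", timestamp)
os.makedirs(output_dir, exist_ok=True)
```
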
fuson_plm/benchmarking/puncta/train.py
CHANGED
@@ -5,7 +5,7 @@ import numpy as np
 import pickle
 import os
 
-from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
+from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
 from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
 import fuson_plm.benchmarking.puncta.config as config
 from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
fuson_plm/benchmarking/xgboost_predictor.py
ADDED
@@ -0,0 +1,65 @@
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, roc_auc_score, average_precision_score
from fuson_plm.utils.logging import log_update
import time
import xgboost as xgb
import numpy as np
import pandas as pd

def train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist"):
    """Train an XGBoost classifier on the full training set and return the fitted model."""
    clf = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
    clf.fit(X_train, y_train)
    return clf

def evaluate_predictor(clf, X_test, y_test, class1_thresh=None):
    """Evaluate a fitted classifier on the test set; optionally re-threshold class-1 probabilities at class1_thresh."""
    # Predict labels on the test set
    y_pred_test = clf.predict(X_test)  # labels with the automatic (0.5) threshold
    y_pred_prob_test = clf.predict_proba(X_test)[:, 1]
    if class1_thresh is not None:
        y_pred_customthresh_test = np.where(np.array(y_pred_prob_test) >= class1_thresh, 1, 0)

    # Calculate metrics - automatic threshold
    accuracy = accuracy_score(y_test, y_pred_test)
    precision = precision_score(y_test, y_pred_test)
    recall = recall_score(y_test, y_pred_test)
    f1 = f1_score(y_test, y_pred_test)
    auroc_prob = roc_auc_score(y_test, y_pred_prob_test)
    auprc_prob = average_precision_score(y_test, y_pred_prob_test)
    auroc_label = roc_auc_score(y_test, y_pred_test)
    auprc_label = average_precision_score(y_test, y_pred_test)

    automatic_stats_df = pd.DataFrame(data={
        'Accuracy': [accuracy],
        'Precision': [precision],
        'Recall': [recall],
        'F1 Score': [f1],
        'AUROC': [auroc_prob],
        'AUROC Label': [auroc_label],
        'AUPRC': [auprc_prob],
        'AUPRC Label': [auprc_label]
    })

    # Calculate metrics - custom threshold (the probability-based AUROC/AUPRC do not change)
    if class1_thresh is not None:
        accuracy_custom = accuracy_score(y_test, y_pred_customthresh_test)
        precision_custom = precision_score(y_test, y_pred_customthresh_test)
        recall_custom = recall_score(y_test, y_pred_customthresh_test)
        f1_custom = f1_score(y_test, y_pred_customthresh_test)
        auroc_prob_custom = roc_auc_score(y_test, y_pred_prob_test)
        auprc_prob_custom = average_precision_score(y_test, y_pred_prob_test)
        auroc_label_custom = roc_auc_score(y_test, y_pred_customthresh_test)
        auprc_label_custom = average_precision_score(y_test, y_pred_customthresh_test)

        custom_stats_df = pd.DataFrame(data={
            'Accuracy': [accuracy_custom],
            'Precision': [precision_custom],
            'Recall': [recall_custom],
            'F1 Score': [f1_custom],
            'AUROC': [auroc_prob_custom],
            'AUROC Label': [auroc_label_custom],
            'AUPRC': [auprc_prob_custom],
            'AUPRC Label': [auprc_label_custom]
        })
    else:
        custom_stats_df = None

    return automatic_stats_df, custom_stats_df
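
A minimal end-to-end sketch of how these two helpers can be exercised. The data here is synthetic (random features and labels, with an arbitrary embedding dimension); in the benchmarks the feature matrix would be the averaged per-sequence embeddings produced by `embed_dataset_for_benchmark`.

```
import numpy as np
from sklearn.model_selection import train_test_split
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor

# Synthetic stand-in for averaged sequence embeddings and binary labels
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 1280))
y = rng.integers(0, 2, size=200)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)

clf = train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist")

# automatic_stats uses the default 0.5 decision threshold; custom_stats re-thresholds class-1 probabilities at 0.4
automatic_stats, custom_stats = evaluate_predictor(clf, X_test, y_test, class1_thresh=0.4)
print(automatic_stats.T)
print(custom_stats.T)
```
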