1 |
# Dependencies
2 |
3 |
Here we provide the package versions needed to run the FusOn-pLM code. For the project, Docker containers were used. We provide a pip list of what is inside each Docker container, as well as the images used for our containers.
4 |
5 |
## pip installs
6 |
7 |
The following dependencies were used for all training and benchmarking except for the `puncta` benchmarks.
8 |
Note that after cloning the repository, you will need to run `pip install -e .` outside the `fuson_plm` directory to install the `fuson_plm` package.
9 |
10 |
Package Version Editable project location
11 |
------------------------- -------------------- -------------------------
12 |
absl-py 1.4.0
13 |
aiohttp 3.8.4
14 |
aiosignal 1.3.1
15 |
apex 0.1
16 |
argon2-cffi 21.3.0
17 |
argon2-cffi-bindings 21.2.0
18 |
asttokens 2.2.1
19 |
astunparse 1.6.3
20 |
async-timeout 4.0.2
21 |
attrs 23.1.0
22 |
audioread 3.0.0
23 |
backcall 0.2.0
24 |
beautifulsoup4 4.12.2
25 |
bio 1.7.1
26 |
biopython 1.84
27 |
biothings-client 0.3.1
28 |
bleach 6.0.0
29 |
blis 0.7.10
30 |
cachetools 5.3.1
31 |
catalogue 2.0.9
32 |
certifi 2023.7.22
33 |
cffi 1.15.1
34 |
charset-normalizer 3.2.0
35 |
click 8.1.5
36 |
cloudpickle 2.2.1
37 |
cmake 3.27.1
38 |
comm 0.1.4
39 |
confection 0.1.1
40 |
contourpy 1.1.0
41 |
cubinlinker 0.3.0+2.g7c3675e
42 |
cuda-python 12.1.0rc5+1.g994d8d0
43 |
cudf 23.6.0
44 |
cugraph 23.6.0
45 |
cugraph-dgl 23.6.0
46 |
cugraph-service-client 23.6.0
47 |
cugraph-service-server 23.6.0
48 |
cuml 23.6.0
49 |
cupy-cuda12x 12.1.0
50 |
cycler 0.11.0
51 |
cymem 2.0.7
52 |
Cython 3.0.0
53 |
dask 2023.3.2
54 |
dask-cuda 23.6.0
55 |
dask-cudf 23.6.0
56 |
debugpy 1.6.7
57 |
decorator 5.1.1
58 |
defusedxml 0.7.1
59 |
distributed 2023.3.2.1
60 |
dm-tree 0.1.8
61 |
docker-pycreds 0.4.0
62 |
einops 0.6.1
63 |
exceptiongroup 1.1.2
64 |
execnet 2.0.2
65 |
executing 1.2.0
66 |
expecttest 0.1.3
67 |
fair-esm 2.0.0
68 |
fastjsonschema 2.18.0
69 |
fastrlock 0.8.1
70 |
filelock 3.12.2
71 |
flash-attn 2.0.4
72 |
fonttools 4.42.0
73 |
frozenlist 1.4.0
74 |
fsspec 2023.6.0
75 |
fuson-plm 1.0 /workspace/FusOn-pLM
76 |
gast 0.5.4
77 |
gdown 5.2.0
78 |
gitdb 4.0.11
79 |
GitPython 3.1.43
80 |
google-auth 2.22.0
81 |
google-auth-oauthlib 0.4.6
82 |
gprofiler-official 1.0.0
83 |
graphsurgeon 0.4.6
84 |
grpcio 1.56.2
85 |
huggingface-hub 0.25.2
86 |
hypothesis 5.35.1
87 |
idna 3.4
88 |
importlib-metadata 6.8.0
89 |
iniconfig 2.0.0
90 |
intel-openmp 2021.4.0
91 |
ipykernel 6.25.0
92 |
ipython 8.14.0
93 |
ipython-genutils 0.2.0
94 |
jedi 0.19.0
95 |
Jinja2 3.1.2
96 |
joblib 1.3.1
97 |
json5 0.9.14
98 |
jsonschema 4.18.6
99 |
jsonschema-specifications 2023.7.1
100 |
jupyter_client 8.3.0
101 |
jupyter_core 5.3.1
102 |
jupyter-tensorboard 0.2.0
103 |
jupyterlab 2.3.2
104 |
jupyterlab-pygments 0.2.2
105 |
jupyterlab-server 1.2.0
106 |
jupytext 1.15.0
107 |
kiwisolver 1.4.4
108 |
langcodes 3.3.0
109 |
librosa 0.9.2
110 |
lightning-utilities 0.11.8
111 |
llvmlite 0.40.1
112 |
locket 1.0.0
113 |
Markdown 3.4.4
114 |
markdown-it-py 3.0.0
115 |
MarkupSafe 2.1.3
116 |
matplotlib 3.7.2
117 |
matplotlib-inline 0.1.6
118 |
mdit-py-plugins 0.4.0
119 |
mdurl 0.1.2
120 |
mistune 3.0.1
121 |
mkl 2021.1.1
122 |
mkl-devel 2021.1.1
123 |
mkl-include 2021.1.1
124 |
mock 5.1.0
125 |
mpmath 1.3.0
126 |
msgpack 1.0.5
127 |
multidict 6.0.4
128 |
murmurhash 1.0.9
129 |
mygene 3.2.2
130 |
nbclient 0.8.0
131 |
nbconvert 7.7.3
132 |
nbformat 5.9.2
133 |
nest-asyncio 1.5.7
134 |
networkx 2.6.3
135 |
ninja 1.11.1
136 |
notebook 6.4.10
137 |
numba 0.57.1+1.gc785c8f1f
138 |
numpy 1.22.2
139 |
140 |
nvidia-cuda-cupti-cu12 12.4.127
141 |
nvidia-cuda-nvrtc-cu12 12.4.127
142 |
nvidia-cuda-runtime-cu12 12.4.127
143 |
144 |
145 |
146 |
147 |
148 |
nvidia-dali-cuda120 1.28.0
149 |
nvidia-nccl-cu12 2.21.5
150 |
nvidia-nvjitlink-cu12 12.4.127
151 |
nvidia-nvtx-cu12 12.4.127
152 |
nvidia-pyindex 1.0.9
153 |
nvtx 0.2.5
154 |
oauthlib 3.2.2
155 |
onnx 1.14.0
156 |
opencv 4.7.0
157 |
packaging 23.1
158 |
pandas 1.5.2
159 |
pandocfilters 1.5.0
160 |
parso 0.8.3
161 |
partd 1.4.0
162 |
pathy 0.10.2
163 |
pexpect 4.8.0
164 |
pickleshare 0.7.5
165 |
Pillow 9.2.0
166 |
pip 23.2.1
167 |
platformdirs 3.10.0
168 |
pluggy 1.2.0
169 |
ply 3.11
170 |
polygraphy 0.47.1
171 |
pooch 1.7.0
172 |
preshed 3.0.8
173 |
prettytable 3.8.0
174 |
prometheus-client 0.17.1
175 |
prompt-toolkit 3.0.39
176 |
protobuf 4.21.12
177 |
psutil 5.9.4
178 |
ptxcompiler 0.8.1+1.g4a94326
179 |
ptyprocess 0.7.0
180 |
pure-eval 0.2.2
181 |
py3Dmol 2.4.0
182 |
pyarrow 11.0.0
183 |
pyasn1 0.5.0
184 |
pyasn1-modules 0.3.0
185 |
pybind11 2.11.1
186 |
pycocotools 2.0+nv0.7.3
187 |
pycparser 2.21
188 |
pydantic 1.10.12
189 |
Pygments 2.16.1
190 |
pylibcugraph 23.6.0
191 |
pylibcugraphops 23.6.0
192 |
pylibraft 23.6.0
193 |
pynndescent 0.5.13
194 |
pynvml 11.4.1
195 |
pyparsing 3.0.9
196 |
PySocks 1.7.1
197 |
pytest 7.4.0
198 |
pytest-flakefinder 1.1.0
199 |
pytest-rerunfailures 12.0
200 |
pytest-shard 0.1.2
201 |
pytest-xdist 3.3.1
202 |
python-dateutil 2.8.2
203 |
python-hostlist 1.23.0
204 |
pytorch-lightning 2.4.0
205 |
pytorch-quantization 2.1.2
206 |
pytz 2023.3
207 |
PyYAML 6.0.1
208 |
pyzmq 25.1.0
209 |
raft-dask 23.6.0
210 |
referencing 0.30.2
211 |
regex 2023.6.3
212 |
requests 2.31.0
213 |
requests-oauthlib 1.3.1
214 |
resampy 0.4.2
215 |
rmm 23.6.0
216 |
rpds-py 0.9.2
217 |
rsa 4.9
218 |
safetensors 0.4.5
219 |
scikit-learn 1.2.0
220 |
scipy 1.11.1
221 |
seaborn 0.13.2
222 |
Send2Trash 1.8.2
223 |
sentencepiece 0.2.0
224 |
sentry-sdk 2.16.0
225 |
setproctitle 1.3.3
226 |
setuptools 68.0.0
227 |
six 1.16.0
228 |
smart-open 6.3.0
229 |
smmap 5.0.1
230 |
sortedcontainers 2.4.0
231 |
soundfile 0.12.1
232 |
soupsieve 2.4.1
233 |
spacy 3.6.0
234 |
spacy-legacy 3.0.12
235 |
spacy-loggers 1.0.4
236 |
sphinx-glpi-theme 0.3
237 |
srsly 2.4.7
238 |
stack-data 0.6.2
239 |
sympy 1.13.1
240 |
tabulate 0.9.0
241 |
tbb 2021.10.0
242 |
tblib 2.0.0
243 |
tensorboard 2.9.0
244 |
tensorboard-data-server 0.6.1
245 |
tensorboard-plugin-wit 1.8.1
246 |
tensorrt 8.6.1
247 |
terminado 0.17.1
248 |
thinc 8.1.10
249 |
threadpoolctl 3.2.0
250 |
thriftpy2 0.4.16
251 |
tinycss2 1.2.1
252 |
tokenizers 0.20.1
253 |
toml 0.10.2
254 |
tomli 2.0.1
255 |
toolz 0.12.0
256 |
torch 2.5.0
257 |
torch-tensorrt 2.0.0.dev0
258 |
torchdata 0.7.0a0
259 |
torchmetrics 1.5.0
260 |
torchtext 0.16.0a0
261 |
torchvision 0.16.0a0
262 |
tornado 6.3.2
263 |
tqdm 4.65.0
264 |
traitlets 5.9.0
265 |
transformer-engine 0.11.0+3f01b4f
266 |
transformers 4.45.2
267 |
treelite 3.2.0
268 |
treelite-runtime 3.2.0
269 |
triton 3.1.0
270 |
typer 0.9.0
271 |
types-dataclasses 0.6.6
272 |
typing_extensions 4.12.2
273 |
ucx-py 0.32.0
274 |
uff 0.6.9
275 |
umap-learn 0.5.6
276 |
urllib3 1.26.16
277 |
wandb 0.18.3
278 |
wasabi 1.1.2
279 |
wcwidth 0.2.6
280 |
webencodings 0.5.1
281 |
Werkzeug 2.3.6
282 |
wheel 0.41.1
283 |
xdoctest 1.0.2
284 |
xgboost 1.7.5
285 |
yarl 1.9.2
286 |
zict 3.0.0
287 |
zipp 3.16.2
288 |
289 |
The following packages and versions were used for the `puncta` benchmarks. A different environment was required to run ProtT5.
290 |
291 |
Package Version Editable project location
292 |
------------------------- -------------------------- -------------------------
293 |
absl-py 2.1.0
294 |
aiohttp 3.9.3
295 |
aiosignal 1.3.1
296 |
annotated-types 0.6.0
297 |
anyio 4.8.0
298 |
apex 0.1
299 |
argon2-cffi 23.1.0
300 |
argon2-cffi-bindings 21.2.0
301 |
asttokens 2.4.1
302 |
astunparse 1.6.3
303 |
async-timeout 4.0.3
304 |
attrs 23.2.0
305 |
audioread 3.0.1
306 |
beautifulsoup4 4.12.3
307 |
bio 1.7.1
308 |
biopython 1.85
309 |
biothings_client 0.4.1
310 |
bleach 6.1.0
311 |
blis 0.7.11
312 |
cachetools 5.3.3
313 |
catalogue 2.0.10
314 |
certifi 2024.2.2
315 |
cffi 1.16.0
316 |
charset-normalizer 3.3.2
317 |
click 8.1.7
318 |
cloudpathlib 0.16.0
319 |
cloudpickle 3.0.0
320 |
321 |
comm 0.2.2
322 |
confection 0.1.4
323 |
contourpy 1.2.1
324 |
cuda-python 12.4.0rc7+3.ge75c8a9.dirty
325 |
cudf 24.2.0
326 |
cudnn 1.1.2
327 |
cugraph 24.2.0
328 |
cugraph-dgl 24.2.0
329 |
cugraph-service-client 24.2.0
330 |
cugraph-service-server 24.2.0
331 |
cuml 24.2.0
332 |
cupy-cuda12x 13.0.0
333 |
cycler 0.12.1
334 |
cymem 2.0.8
335 |
Cython 3.0.10
336 |
dask 2024.1.1
337 |
dask-cuda 24.2.0
338 |
dask-cudf 24.2.0
339 |
debugpy 1.8.1
340 |
decorator 5.1.1
341 |
defusedxml 0.7.1
342 |
distributed 2024.1.1
343 |
dm-tree 0.1.8
344 |
docker-pycreds 0.4.0
345 |
einops 0.7.0
346 |
exceptiongroup 1.2.0
347 |
execnet 2.0.2
348 |
executing 2.0.1
349 |
expecttest 0.1.3
350 |
fair-esm 2.0.0
351 |
fastjsonschema 2.19.1
352 |
fastrlock 0.8.2
353 |
filelock 3.13.3
354 |
flash-attn 2.4.2
355 |
fonttools 4.51.0
356 |
frozenlist 1.4.1
357 |
fsspec 2024.2.0
358 |
fuson-plm 1.0 /workspace/FusOn-pLM
359 |
gast 0.5.4
360 |
gdown 5.2.0
361 |
gitdb 4.0.12
362 |
GitPython 3.1.44
363 |
google-auth 2.29.0
364 |
google-auth-oauthlib 0.4.6
365 |
gprofiler-official 1.0.0
366 |
graphsurgeon 0.4.6
367 |
grpcio 1.62.1
368 |
h11 0.14.0
369 |
httpcore 1.0.7
370 |
httpx 0.28.1
371 |
huggingface-hub 0.27.1
372 |
hypothesis 5.35.1
373 |
idna 3.6
374 |
igraph 0.11.4
375 |
importlib_metadata 7.0.2
376 |
iniconfig 2.0.0
377 |
intel-openmp 2021.4.0
378 |
ipykernel 6.29.4
379 |
ipython 8.21.0
380 |
ipython-genutils 0.2.0
381 |
jedi 0.19.1
382 |
Jinja2 3.1.3
383 |
joblib 1.3.2
384 |
json5 0.9.24
385 |
jsonschema 4.21.1
386 |
jsonschema-specifications 2023.12.1
387 |
jupyter_client 8.6.1
388 |
jupyter_core 5.7.2
389 |
jupyter-tensorboard 0.2.0
390 |
jupyterlab 2.3.2
391 |
jupyterlab_pygments 0.3.0
392 |
jupyterlab-server 1.2.0
393 |
jupytext 1.16.1
394 |
kiwisolver 1.4.5
395 |
langcodes 3.3.0
396 |
lark 1.1.9
397 |
lazy_loader 0.4
398 |
librosa 0.10.1
399 |
lightning-thunder 0.1.0
400 |
lightning-utilities 0.11.2
401 |
llvmlite 0.42.0
402 |
locket 1.0.0
403 |
looseversion 1.3.0
404 |
Markdown 3.6
405 |
markdown-it-py 3.0.0
406 |
MarkupSafe 2.1.5
407 |
matplotlib 3.8.4
408 |
matplotlib-inline 0.1.6
409 |
mdit-py-plugins 0.4.0
410 |
mdurl 0.1.2
411 |
mistune 3.0.2
412 |
mkl 2021.1.1
413 |
mkl-devel 2021.1.1
414 |
mkl-include 2021.1.1
415 |
mock 5.1.0
416 |
mpmath 1.3.0
417 |
msgpack 1.0.8
418 |
multidict 6.0.5
419 |
murmurhash 1.0.10
420 |
mygene 3.2.2
421 |
nbclient 0.10.0
422 |
nbconvert 7.16.3
423 |
nbformat 5.10.4
424 |
nest-asyncio 1.6.0
425 |
networkx 2.6.3
426 |
427 |
notebook 6.4.10
428 |
numba 0.59.0+1.g20ae2b56c
429 |
numpy 1.24.4
430 |
nvfuser 0.1.6a0+a684e2a
431 |
nvidia-dali-cuda120 1.36.0
432 |
433 |
nvidia-pyindex 1.0.9
434 |
nvtx 0.2.5
435 |
oauthlib 3.2.2
436 |
onnx 1.16.0
437 |
opencv 4.7.0
438 |
opt-einsum 3.3.0
439 |
optree 0.11.0
440 |
packaging 23.2
441 |
pandas 1.5.3
442 |
pandocfilters 1.5.1
443 |
parso 0.8.4
444 |
partd 1.4.1
445 |
pexpect 4.9.0
446 |
pillow 10.2.0
447 |
pip 24.0
448 |
platformdirs 4.2.0
449 |
pluggy 1.4.0
450 |
ply 3.11
451 |
polygraphy 0.49.8
452 |
pooch 1.8.1
453 |
preshed 3.0.9
454 |
prettytable 3.10.0
455 |
prometheus_client 0.20.0
456 |
prompt-toolkit 3.0.43
457 |
protobuf 4.24.4
458 |
psutil 5.9.4
459 |
ptyprocess 0.7.0
460 |
pure-eval 0.2.2
461 |
py3Dmol 2.4.2
462 |
pyarrow 14.0.1
463 |
pyasn1 0.6.0
464 |
pyasn1_modules 0.4.0
465 |
pybind11 2.12.0
466 |
pybind11_global 2.12.0
467 |
pycocotools 2.0+nv0.8.0
468 |
pycparser 2.22
469 |
pydantic 2.6.4
470 |
pydantic_core 2.16.3
471 |
Pygments 2.17.2
472 |
pylibcugraph 24.2.0
473 |
pylibcugraphops 24.2.0
474 |
pylibraft 24.2.0
475 |
pynndescent 0.5.13
476 |
pynvjitlink 0.1.13
477 |
pynvml 11.4.1
478 |
pyparsing 3.1.2
479 |
PySocks 1.7.1
480 |
pytest 8.1.1
481 |
pytest-flakefinder 1.1.0
482 |
pytest-rerunfailures 14.0
483 |
pytest-shard 0.1.2
484 |
pytest-xdist 3.5.0
485 |
python-dateutil 2.9.0.post0
486 |
python-hostlist 1.23.0
487 |
pytorch-lightning 2.5.0.post0
488 |
pytorch-quantization 2.1.2
489 |
pytorch-triton 3.0.0+a9bc1a364
490 |
pytz 2024.1
491 |
PyYAML 6.0.1
492 |
pyzmq 25.1.2
493 |
raft-dask 24.2.0
494 |
rapids-dask-dependency 24.2.0a0
495 |
referencing 0.34.0
496 |
regex 2023.12.25
497 |
requests 2.31.0
498 |
requests-oauthlib 2.0.0
499 |
rich 13.7.1
500 |
rmm 24.2.0
501 |
rpds-py 0.18.0
502 |
rsa 4.9
503 |
safetensors 0.5.2
504 |
scikit-learn 1.2.0
505 |
scipy 1.12.0
506 |
seaborn 0.13.2
507 |
Send2Trash 1.8.2
508 |
sentencepiece 0.2.0
509 |
sentry-sdk 2.20.0
510 |
setproctitle 1.3.4
511 |
setuptools 68.2.2
512 |
six 1.16.0
513 |
smart-open 6.4.0
514 |
smmap 5.0.2
515 |
sniffio 1.3.1
516 |
sortedcontainers 2.4.0
517 |
soundfile 0.12.1
518 |
soupsieve 2.5
519 |
soxr 0.3.7
520 |
spacy 3.7.4
521 |
spacy-legacy 3.0.12
522 |
spacy-loggers 1.0.5
523 |
sphinx_glpi_theme 0.6
524 |
srsly 2.4.8
525 |
stack-data 0.6.3
526 |
sympy 1.12
527 |
tabulate 0.9.0
528 |
tbb 2021.12.0
529 |
tblib 3.0.0
530 |
tensorboard 2.9.0
531 |
tensorboard-data-server 0.6.1
532 |
tensorboard-plugin-wit 1.8.1
533 |
tensorrt 8.6.3
534 |
terminado 0.18.1
535 |
texttable 1.7.0
536 |
thinc 8.2.3
537 |
threadpoolctl 3.3.0
538 |
thriftpy2 0.4.17
539 |
tinycss2 1.2.1
540 |
tokenizers 0.21.0
541 |
toml 0.10.2
542 |
tomli 2.0.1
543 |
toolz 0.12.1
544 |
torch 2.3.0a0+6ddf5cf85e.nv24.4
545 |
torch-tensorrt 2.3.0a0
546 |
torchdata 0.7.1a0
547 |
torchmetrics 1.6.1
548 |
torchtext 0.17.0a0
549 |
torchvision 0.18.0a0
550 |
tornado 6.4
551 |
tqdm 4.66.2
552 |
traitlets 5.9.0
553 |
transformer-engine 1.5.0+6a9edc3
554 |
transformers 4.48.0
555 |
treelite 4.0.0
556 |
typer 0.9.4
557 |
types-dataclasses 0.6.6
558 |
typing_extensions 4.10.0
559 |
ucx-py 0.36.0
560 |
uff 0.6.9
561 |
umap-learn 0.5.7
562 |
urllib3 1.26.18
563 |
wandb 0.19.4
564 |
wasabi 1.1.2
565 |
wcwidth 0.2.13
566 |
weasel 0.3.4
567 |
webencodings 0.5.1
568 |
Werkzeug 3.0.2
569 |
wheel 0.43.0
570 |
xdoctest 1.0.2
571 |
xgboost 1.7.5
572 |
yarl 1.9.4
573 |
zict 3.0.0
574 |
zipp 3.17.0
575 |
576 |
## Docker
577 |
578 |
The following image was used for Container 1 (all code except puncta benchmark):
579 |
580 |
581 |
582 |
583 |
584 |
The following image was used for Container 2 (puncta benchmark):
585 |
586 |
587 |
588 |
1 |
# Benchmarking
2 |
3 |
This outer directory for the benchmarks in FusOn-pLM has some utility functions stored in `.py` files.
4 |
5 |
6 |
7 |
This file contains functions used to make and organize FusOn-pLM and ESM embeddings of benchmarking datasets. Its functions are used in all benchmarks.
8 |
9 |
10 |
11 |
This file contains functions used to train XGBoost predictors, which are utilized in the `puncta` benchmark.
File without changes
1 |
# Python file for making embeddings from a FusOn-pLM model for any dataset
2 |
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
3 |
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
4 |
from fuson_plm.utils.data_cleaning import find_invalid_chars
5 |
from fuson_plm.utils.constants import VALID_AAS
6 |
from import FusOnpLM
7 |
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
8 |
import logging
9 |
import torch
10 |
import pickle
11 |
import os
12 |
import pandas as pd
13 |
import numpy as np
14 |
15 |
def validate_sequence_col(df, seq_col):
    """Validate the sequence column of a dataset before it is embedded.

    The input dataframe is only read; no columns are added to or removed
    from it.

    Args:
        df: pandas DataFrame holding the dataset.
        seq_col: name of the column expected to contain the sequences.

    Raises:
        Exception: if ``seq_col`` is not a column of ``df``, or if
            ``find_invalid_chars`` reports any character outside ``VALID_AAS``
            in at least one sequence.

    Duplicate sequences do not raise; they only produce a logged warning.
    """
    # if column isn't there, error
    if seq_col not in df.columns:
        raise Exception("Error: provided sequence column does not exist in the input dataframe")

    sequences = df[seq_col]

    # if column contains invalid characters, error
    # (computed on a temporary Series so the caller's dataframe is left untouched)
    invalid_per_sequence = sequences.apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*invalid_per_sequence)
    if all_invalid_chars:
        raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")

    # make sure there are no duplicates
    if len(set(sequences)) < len(sequences):
        log_update("\tWARNING: input data has duplicate sequences")
30 |
31 |
def load_fuson_model(ckpt_path):
32 |
# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
33 |
34 |
35 |
# Set device
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
print(f"Using device: {device}")
38 |
39 |
# Load model
40 |
model = AutoModel.from_pretrained(ckpt_path) # initialize model
41 |
tokenizer = AutoTokenizer.from_pretrained(ckpt_path) # initialize tokenizer
42 |
43 |
# Model to device and in eval mode
44 |
45 |
model.eval() # disables dropout for deterministic results
46 |
47 |
return model, tokenizer, device
48 |
49 |
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
50 |
# Correct save path to pickle if necessary
51 |
if savepath is not None:
52 |
if savepath[-4::] != '.pkl': savepath += '.pkl'
53 |
54 |
if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")
55 |
56 |
# If no max length was passed, just set it to the maximum in the dataset
57 |
max_seq_len = max([len(s) for s in sequences])
58 |
if max_length is None: max_length=max_seq_len+2 # add 2 for BOS, EOS
59 |
60 |
# Initialize an empty dict to store the ESM embeddings
61 |
embedding_dict = {}
62 |
# Iterate through the seqs
63 |
for i in range(len(sequences)):
64 |
sequence = sequences[i]
65 |
# Get the embeddings
66 |
with torch.no_grad():
67 |
# Tokenize the input sequence
68 |
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True,max_length=max_length)
69 |
inputs = {k: for k, v in inputs.items()}
70 |
71 |
outputs = model(**inputs)
72 |
# The embeddings are in the last_hidden_state tensor
73 |
embedding = outputs.last_hidden_state
74 |
# remove extra dimension
75 |
embedding = embedding.squeeze(0)
76 |
# remove BOS and EOS tokens
77 |
embedding = embedding[1:-1, :]
78 |
79 |
# Convert embeddings to numpy array (if needed)
80 |
embedding = embedding.cpu().numpy()
81 |
82 |
# Average (if necessary)
83 |
if average:
84 |
embedding = embedding.mean(0)
85 |
86 |
# Add to dictionary
87 |
embedding_dict[sequence] = embedding
88 |
89 |
# Save individual embedding (if necessary)
90 |
if not(savepath is None) and not(save_at_end):
91 |
with open(savepath, 'ab+') as f:
92 |
d = {sequence: embedding}
93 |
pickle.dump(d, f)
94 |
95 |
# Print update (if necessary)
96 |
if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
97 |
98 |
# Dump all at once at the end (if necessary)
99 |
if not(savepath is None):
100 |
# If saving for the first time, just dump it
101 |
if save_at_end:
102 |
with open(savepath, 'wb') as f:
103 |
pickle.dump(embedding_dict, f)
104 |
# If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
105 |
106 |
107 |
108 |
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path = None, average=True, overwrite=True, print_updates=False,max_length=2000):
109 |
# Make sure we aren't overwriting pre-existing embeddings
110 |
if os.path.exists(path_to_output):
111 |
if overwrite:
112 |
log_update(f"WARNING: these embeddings may already exist at {path_to_output} and will be overwritten")
113 |
114 |
log_update(f"WARNING: these embeddings may already exist at {path_to_output}. Skipping.")
115 |
return None
116 |
117 |
dataset = pd.read_csv(path_to_file)
118 |
# Make sure the sequence column is valid
119 |
validate_sequence_col(dataset, seq_col)
120 |
121 |
sequences = dataset[seq_col].unique().tolist() # ensure all entries are unique
122 |
123 |
### If FusOn-pLM: make fusion embeddings
124 |
if model_type=='fuson_plm':
125 |
if not(os.path.exists(fuson_ckpt_path)): raise Exception("FusOn-pLM ckpt path does not exist")
126 |
127 |
# Load model
128 |
129 |
model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
130 |
131 |
raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}")
132 |
133 |
# Generate embeddings
134 |
135 |
get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
136 |
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
137 |
138 |
139 |
raise Exception("Could not generate FusOn-pLM embeddings")
140 |
141 |
if model_type=='esm2_t33_650M_UR50D':
142 |
# Load model
143 |
144 |
model, tokenizer, device = load_esm2_type(model_type)
145 |
146 |
raise Exception(f"Could not load {model_type}")
147 |
# Generate embeddings
148 |
149 |
get_esm_embeddings(model, tokenizer, sequences, device, average=average,
150 |
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
151 |
152 |
153 |
raise Exception(f"Could not generate {model_type} embeddings")
154 |
155 |
if model_type=="prot_t5_xl_half_uniref50_enc":
156 |
# Load model
157 |
158 |
model, tokenizer, device = load_prott5()
159 |
160 |
raise Exception(f"Could not load {model_type}")
161 |
# Generate embeddings
162 |
163 |
get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
164 |
print_updates=print_updates, savepath=path_to_output, save_at_end=False,
165 |
166 |
167 |
raise Exception(f"Could not generate {model_type} embeddings")
168 |
169 |
170 |
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False,max_length=None):
171 |
# make directory for embeddings inside benchmarking dataset if one doesn't already exist
172 |
173 |
174 |
# Extract input file name from configs
175 |
emb_type_tag ='average' if average else '2D'
176 |
177 |
all_embedding_paths = dict() # dictionary organized where embedding path points to model, epoch
178 |
179 |
# make the embedding files. Put them in an embedding directory
180 |
if benchmark_fusonplm:
181 |
182 |
183 |
log_update(f"\nMaking Fuson-PLM embeddings")
184 |
# make subdirs for all the
185 |
if type(fuson_ckpts)==dict:
186 |
for model_name, epoch_list in fuson_ckpts.items():
187 |
188 |
for epoch in epoch_list:
189 |
# Assemble ckpt path and throw error if it doesn't exist
190 |
fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
191 |
if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
192 |
193 |
# Make output directory and output embedding path
194 |
embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
195 |
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
196 |
197 |
198 |
# Make dictionary item
199 |
model_type = 'fuson_plm'
200 |
all_embedding_paths[embedding_output_path] = {
201 |
'model_type': model_type,
202 |
'model': model_name,
203 |
'epoch': epoch
204 |
205 |
206 |
# Create embeddings (or skip if they're already made)
207 |
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
208 |
embed_dataset(input_data_path, embedding_output_path,
209 |
seq_col=seq_col, model_type=model_type,
210 |
fuson_ckpt_path=fuson_ckpt_path, average=average,
211 |
212 |
213 |
elif fuson_ckpts=="FusOn-pLM":
214 |
model_name = "best"
215 |
216 |
217 |
# Assemble ckpt path and throw error if it doesn't exist
218 |
fuson_ckpt_path = "../../.." # go back to the FusOn-pLM directory to find the best ckpt
219 |
if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
220 |
221 |
# Make output directory and output embedding path
222 |
embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
223 |
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
224 |
225 |
226 |
# Make dictionary item
227 |
model_type = 'fuson_plm'
228 |
all_embedding_paths[embedding_output_path] = {
229 |
'model_type': model_type,
230 |
'model': model_name,
231 |
'epoch': None
232 |
233 |
234 |
# Create embeddings (or skip if they're already made)
235 |
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
236 |
embed_dataset(input_data_path, embedding_output_path,
237 |
seq_col=seq_col, model_type=model_type,
238 |
fuson_ckpt_path=fuson_ckpt_path, average=average,
239 |
240 |
241 |
242 |
raise Exception(f"Error. fuson_ckpts should be a dict or str")
243 |
244 |
# make the embedding files. Put them in an embedding directory
245 |
if benchmark_esm:
246 |
247 |
248 |
# make output path
249 |
embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'
250 |
251 |
# Make dictionary item
252 |
model_type = 'esm2_t33_650M_UR50D'
253 |
all_embedding_paths[embedding_output_path] = {
254 |
'model_type': model_type,
255 |
'model': model_type,
256 |
'epoch': np.nan
257 |
258 |
259 |
log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
260 |
embed_dataset(input_data_path, embedding_output_path,
261 |
seq_col=seq_col, model_type=model_type,
262 |
fuson_ckpt_path = None, average=average,
263 |
264 |
265 |
266 |
if benchmark_prott5:
267 |
268 |
269 |
# make output path
270 |
embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'
271 |
272 |
# Make dictionary item
273 |
model_type = 'prot_t5_xl_half_uniref50_enc'
274 |
all_embedding_paths[embedding_output_path] = {
275 |
'model_type': model_type,
276 |
'model': model_type,
277 |
'epoch': np.nan
278 |
279 |
280 |
log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
281 |
embed_dataset(input_data_path, embedding_output_path,
282 |
seq_col=seq_col, model_type=model_type,
283 |
fuson_ckpt_path = None, average=average,
284 |
285 |
286 |
287 |
if benchmark_fo_puncta_ml:
288 |
embedding_output_path =f'FOdb_physicochemical_embeddings.pkl'
289 |
# Make dictionary item
290 |
all_embedding_paths[embedding_output_path] = {
291 |
'model_type': 'fo_puncta_ml',
292 |
'model': 'fo_puncta_ml',
293 |
'epoch': np.nan
294 |
295 |
296 |
return all_embedding_paths
1 |
# Embedding exploration
2 |
3 |
This folder contains all the data and code needed to run embedding exploration (Fig. S3).
4 |
5 |
### Data download
6 |
To help select TF (transcription factor)- and kinase-containing fusions for investigation (Fig. S3a), Supplementary Table 3 from Salokas et al. (2020) was downloaded as a reference of transcription factors and kinases.
7 |
8 |
9 |
10 |
└── embedding_exploration/
11 |
└── data/
12 |
├── salokas_2020_tableS3.csv
13 |
├── tf_and_kinase_fusions.csv
14 |
├── top_genes.csv
15 |
16 |
17 |
- **`data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from Salokas et al. (2020).
18 |
- **`data/tf_and_kinase_fusions.csv`**: set of TF::TF and Kinase::Kinase fusion oncoproteins from FusOn-DB database. Curated in ``
19 |
- **`data/top_genes.csv`**: fusion oncoproteins (and their head and tail components) visualized in Fig. S3b. Sequences for head and tail components were pulled from the best-aligned sequences in `fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl`
20 |
21 |
### Plotting
22 |
23 |
Run `` to regenerate plots in Figure S3:
24 |
25 |
26 |
# Dictionary: key = run name, values = epochs. (use this option if you've trained your own model)
27 |
# # Or "FusOn-pLM" to use official model
28 |
29 |
30 |
# Type of dim reduction
31 |
32 |
33 |
34 |
# Overwriting configs
35 |
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
36 |
37 |
38 |
To run, use:
39 |
40 |
nohup python > plot.out 2> plot.err &
41 |
42 |
- All **results** are stored in `embedding_exploration/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started the run.
43 |
44 |
Below are the FusOn-pLM paper results in `results/final/umap_plots/fuson_plm/best/`:
45 |
46 |
47 |
48 |
└── embedding_exploration/
49 |
└── results/final/umap_plots/fuson_plm/best/
50 |
└── favorites/
51 |
├── umap_favorites_source_data.csv
52 |
├── umap_favorites_visualization.png
53 |
└── tf_and_kinase/
54 |
├── umap_tf_and_kinase_fusions_source_data.csv
├── umap_tf_and_kinase_fusions_visualization.png
55 |
56 |
57 |
- **`favorites/umap_favorites_visualization.png`**: Fig. S3b, with the data directly plotted stored in `favorites/umap_favorites_source_data.csv`
58 |
- **`tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png`**: Fig. S3a, with the data directly plotted stored in `tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv`.
File without changes
1 |
# Dictionary: key = run name, values = epochs. (use this option if you've trained your own model)
2 |
# # Or, List: item goes to path (use this option if you're using the "best" ckpt from FusOn-pLM paper)
3 |
4 |
5 |
# Type of dim reduction
6 |
7 |
8 |
9 |
# Overwriting configs
10 |
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
1 |
2 |
oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
3 |
size 19626
1 |
2 |
oid sha256:372321137ed12b2f8aa7c4891dafd0e88d64d5c5d0ea9c6f3a0aa9d897e8ead6
3 |
size 557262
1 |
2 |
oid sha256:33d568fe413107318caebd5ee260ee66fe8571461ed8f8d1b47888441f7b5034
3 |
size 16695
1 |
import pandas as pd
2 |
import numpy as np
3 |
import pickle
4 |
from sklearn.manifold import TSNE
5 |
import matplotlib.font_manager as fm
6 |
from matplotlib.font_manager import FontProperties
7 |
import matplotlib.pyplot as plt
8 |
import matplotlib.gridspec as gridspec
9 |
import matplotlib.patches as patches
10 |
import seaborn as sns
11 |
import umap
12 |
import os
13 |
14 |
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
15 |
import fuson_plm.benchmarking.embedding_exploration.config as config
16 |
from fuson_plm.utils.visualizing import set_font
17 |
from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
18 |
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy
19 |
20 |
21 |
def get_dimred_embeddings(embeddings, dimred_type="umap"):
    """Dispatch to the requested dimensionality-reduction routine.

    Args:
        embeddings: embeddings to reduce; passed through unchanged to the
            selected routine.
        dimred_type: ``"umap"`` selects ``get_umap_embeddings`` and ``"tsne"``
            selects ``get_tsne_embeddings``.

    Returns:
        The selected routine's result, or ``None`` when ``dimred_type`` is
        neither ``"umap"`` nor ``"tsne"``.
    """
    if dimred_type == "umap":
        return get_umap_embeddings(embeddings)
    if dimred_type == "tsne":
        return get_tsne_embeddings(embeddings)
    # Unrecognised type: fall through without reducing anything
    return None
28 |
29 |
def get_tsne_embeddings(embeddings):
    """Reduce embeddings to two components with t-SNE.

    Args:
        embeddings: array-like of embeddings; converted with ``np.array``
            before fitting.

    Returns:
        The result of ``TSNE.fit_transform`` on the converted array.
    """
    # Fixed random_state so repeated runs use the same seed; perplexity is set to 5
    reducer = TSNE(n_components=2, random_state=42, perplexity=5)
    return reducer.fit_transform(np.array(embeddings))
34 |
35 |
def get_umap_embeddings(embeddings):
    """Reduce embeddings to two components with UMAP.

    Args:
        embeddings: array-like of embeddings; converted with ``np.array``
            before fitting.

    Returns:
        The result of ``umap.UMAP.fit_transform`` on the converted array.
    """
    # Fixed random_state so repeated runs use the same seed;
    # n_neighbors and metric are UMAP's default parameters
    reducer = umap.UMAP(
        n_components=2,
        random_state=42,
        n_neighbors=15,
        metric='euclidean',
    )
    return reducer.fit_transform(np.array(embeddings))
40 |
41 |
def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
42 |
43 |
Plots a circle filled in halves with specified colors.
44 |
45 |
46 |
- ax: Matplotlib axis to draw on.
47 |
- x, y: Coordinates of the marker.
48 |
- left_color: Color of the left half.
49 |
- right_color: Color of the right half.
50 |
- size: Size of the marker.
51 |
52 |
radius = (size ** 0.5) / 100 # Scale the radius
53 |
# Create left half-circle (0° to 180°)
54 |
left_half = patches.Wedge((x, y), radius, 90, 270, color=left_color, ec="black")
55 |
# Create right half-circle (180° to 360°)
56 |
right_half = patches.Wedge((x, y), radius, 270, 90, color=right_color, ec="black")
57 |
58 |
# Add both halves to the plot
59 |
60 |
61 |
62 |
def plot_umap_scatter_tftf_kk(df, filename="umap.png"):
    """
    Plots a 2D scatterplot of UMAP coordinates colored by fusion type.
    Only for TF::TF and Kinase::Kinase fusions

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', and 'fusion_type' columns.
    - filename (str): Path the figure is written to.

    Raises:
    - KeyError: if a 'fusion_type' value is not "TF::TF" or "Kinase::Kinase".
    """
    # Define colors for each type
    colors = {
        "TF": "pink",
        "Kinase": "orange",
    }

    # Define the color for each supported head::tail combination
    marker_colors = {
        "TF::TF": colors["TF"],
        "Kinase::Kinase": colors["Kinase"],
    }

    # Create the plot, padding the axis limits by one unit around the data
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    # Resolve one color per row (in row order), then draw all points in a single call
    point_colors = [marker_colors[fusion_type] for fusion_type in df["fusion_type"]]
    ax.scatter(df["umap1"], df["umap2"], c=point_colors, s=15, edgecolors="black", linewidth=0.5)

    # Add custom legend, built from the same color table used for the points
    legend_elements = [
        patches.Patch(facecolor=marker_colors[label], edgecolor="black", label=label)
        for label in ("TF::TF", "Kinase::Kinase")
    ]
    ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)

    # Add labels and title
    ax.set_xlabel("UMAP 1", fontsize=20)
    ax.set_ylabel("UMAP 2", fontsize=20)
    ax.set_title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)

    # Save the plot, then release the figure so repeated calls do not
    # accumulate open figures
    fig.savefig(filename, dpi=300)
    plt.close(fig)
116 |
117 |
def plot_umap_scatter_half_filled(df, filename="umap.png"):
    """
    Plots a 2D scatterplot of UMAP coordinates where each point is a circle whose
    left half is colored by the head gene type and right half by the tail gene type.

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', and 'fusion_type' columns.
      'fusion_type' values have the form "<head type>::<tail type>" with each type
      being "TF", "Kinase", or "Other".
    - filename (str): Path the figure is written to.

    Raises:
    - KeyError: if a 'fusion_type' value is not one of the nine supported combinations.
    """
    # Define colors for each type
    colors = {
        "TF": "pink",
        "Kinase": "orange",
        "Other": "grey",
    }

    # Left/right colors for every head::tail combination of the types above
    marker_colors = {
        f"{head}::{tail}": {"left": colors[head], "right": colors[tail]}
        for head in colors
        for tail in colors
    }

    # Create the plot, padding the axis limits by one unit around the data
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    # Plot each point with the specified half-filled marker
    for fusion_type, x, y in zip(df["fusion_type"], df["umap1"], df["umap2"]):
        halves = marker_colors[fusion_type]
        plot_half_filled_circle(ax, x, y, halves["left"], halves["right"], size=100)

    # Add custom legend, one entry per gene type
    legend_elements = [
        patches.Patch(facecolor=color, edgecolor="black", label=label)
        for label, color in colors.items()
    ]
    ax.legend(handles=legend_elements, title="Type")

    # Add labels and title
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.set_title("UMAP Scatter Plot")

    # Save the plot, then release the figure so repeated calls do not
    # accumulate open figures
    fig.savefig(filename, dpi=300)
    plt.close(fig)
177 |
178 |
def get_gene_type(gene, d):
    """
    Classify a gene as 'Kinase', 'TF', or 'Other'.

    Parameters:
    - gene: Gene name to look up.
    - d (dict): Maps gene names to an annotation; 'kinase' and 'tf' are the
      only annotations that are recognized.

    Returns:
    - 'Kinase' if d[gene] == 'kinase', 'TF' if d[gene] == 'tf', otherwise 'Other'
      (including genes missing from d and genes with any other annotation).
    """
    annotation = d.get(gene)
    if annotation == 'kinase':
        return 'Kinase'
    if annotation == 'tf':
        return 'TF'
    # Always return a label, so callers can concatenate the result into a
    # "head::tail" fusion type without handling None.
    return 'Other'
186 |
187 |
def get_tf_and_kinase_fusions_dataset():
188 |
# Load TF and Kinase Fusions
189 |
tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
190 |
191 |
ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'],tf_kinase_parts['Kinase or TF']))
192 |
193 |
# This one has each row with one fusiongene name
194 |
fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
195 |
fuson_ht_db[['hg','tg']] = fuson_ht_db['fusiongenes'].str.split("::",expand=True)
196 |
197 |
fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
198 |
fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
199 |
fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type']+'::'+fuson_ht_db['tg_type']
200 |
201 |
# Keep 100 things in each category
202 |
categories = pd.DataFrame(fuson_ht_db['fusion_type'].value_counts()).reset_index()['index'].tolist()
203 |
categories = ["TF::TF","Kinase::Kinase"] # manually set some easier categories
204 |
205 |
plot_df = None
206 |
207 |
for i, cat in enumerate(categories):
208 |
random_sample = fuson_ht_db.loc[fuson_ht_db['fusion_type']==cat].reset_index(drop=True)
209 |
#random_sample = random_sample.sample(n=100, random_state=1).reset_index(drop=True)
210 |
if i==0:
211 |
plot_df = random_sample
212 |
213 |
plot_df = pd.concat([plot_df,random_sample],axis=0).reset_index(drop=True)
214 |
215 |
216 |
217 |
# Now, need to add in the embeddings
218 |
plot_df = plot_df[['aa_seq','fusiongenes','fusion_type','type']].rename(
219 |
220 |
221 |
222 |
return plot_df
223 |
224 |
def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = '', dimred_type='umap'):
225 |
fuson_db = pd.read_csv("../../data/fuson_db.csv")
226 |
seq_id_dict = dict(zip(fuson_db['aa_seq'],fuson_db['seq_id']))
227 |
228 |
# add sequences so we can save results/sequence
229 |
data = seqs_with_embeddings[[f'{dimred_type}1',f'{dimred_type}2','sequence','fusion_type','ID']]
230 |
data['seq_id'] = data['sequence'].map(seq_id_dict)
231 |
232 |
tfkinase_save_dir = f"{savedir}"
233 |
234 |
235 |
236 |
237 |
def tf_and_kinase_fusions_plot(dimred_types, output_dir):
238 |
239 |
Makes the embeddings, THEN calls the plot. only on the four favorites
240 |
241 |
plot_df = get_tf_and_kinase_fusions_dataset()
242 |
243 |
244 |
# path to the pkl file with FOdb embeddings
245 |
246 |
all_embedding_paths = embed_dataset_for_benchmark(
247 |
248 |
input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
249 |
average=True, seq_col='sequence',
250 |
251 |
252 |
253 |
254 |
255 |
# For each of the models we are benchmarking, load embeddings and make plots
256 |
log_update("\nEmbedding sequences")
257 |
# loop through the embedding paths and train each one
258 |
for embedding_path, details in all_embedding_paths.items():
259 |
log_update(f"\tBenchmarking embeddings at: {embedding_path}")
260 |
261 |
with open(embedding_path, "rb") as f:
262 |
embeddings = pickle.load(f)
263 |
264 |
raise Exception(f"Cannot read embeddings from {embedding_path}")
265 |
266 |
# combine the embeddings and splits into one dataframe
267 |
seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
268 |
seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'}) # the column that was called FusOn-pLM is now called embedding
269 |
seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')
270 |
# get UMAP transform of the embeddings
271 |
for dimred_type in dimred_types:
272 |
dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(),dimred_type=dimred_type)
273 |
274 |
# turn the result into a dataframe, and add it to seqs_with_embeddings
275 |
data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
276 |
# save the umap data!
277 |
model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])
278 |
279 |
seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data
280 |
281 |
# make subdirectory
282 |
intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
283 |
cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
284 |
285 |
286 |
make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = cur_output_dir, dimred_type=dimred_type)
287 |
288 |
def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir = None, dimred_type='umap'):
289 |
290 |
Make plots showing that PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK are embedded distinctly from their heads and tails
291 |
292 |
293 |
294 |
# Load one sequence each for four proteins in the test set: PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK
295 |
data = pd.read_csv("data/top_genes.csv")
296 |
seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")
297 |
seqs_with_embeddings["Type"] = [""]*len(seqs_with_embeddings)
298 |
299 |
300 |
] = "fusion_embeddings"
301 |
heads = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::",expand=True)[0].tolist()
302 |
tails = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::",expand=True)[1].tolist()
303 |
304 |
305 |
] = "h_embeddings"
306 |
307 |
308 |
] = "t_embeddings"
309 |
310 |
# make merge
311 |
merge = seqs_with_embeddings.loc[seqs_with_embeddings['gene'].str.contains('::')].reset_index(drop=True)[['gene','sequence']]
312 |
merge["head"] = merge["gene"].str.split("::",expand=True)[0]
313 |
merge["tail"] = merge["gene"].str.split("::",expand=True)[1]
314 |
merge = pd.merge(merge, seqs_with_embeddings[['gene','sequence']].rename(
315 |
columns={'gene': 'head', 'sequence': 'h_sequence'}),
316 |
317 |
318 |
merge = pd.merge(merge, seqs_with_embeddings[['gene','sequence']].rename(
319 |
columns={'gene': 'tail', 'sequence': 't_sequence'}),
320 |
321 |
322 |
323 |
324 |
325 |
# Define colors and markers
326 |
colors = {
327 |
'fusion_embeddings': '#cf9dfa', # old color #0C4A4D
328 |
'h_embeddings': '#eb8888', # Updated to original names; old color #619283
329 |
't_embeddings': '#5fa3e3', # Updated to original names; old color #619283
330 |
331 |
markers = {
332 |
'fusion_embeddings': 'o',
333 |
'h_embeddings': '^', # Updated to original names
334 |
't_embeddings': 'v' # Updated to original names
335 |
336 |
label_map = {
337 |
'fusion_embeddings': 'Fusion',
338 |
'h_embeddings': 'Head', # Updated label
339 |
't_embeddings': 'Tail', # Updated label
340 |
341 |
342 |
# Create a 2x3 grid of plots
343 |
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
344 |
#fig, axes = plt.subplots(1, 4, figsize= (18, 7))
345 |
346 |
# Get the global min and max for the x and y axis ranges
347 |
all_tsne1 = seqs_with_embeddings[f'{dimred_type}1']
348 |
all_tsne2 = seqs_with_embeddings[f'{dimred_type}2']
349 |
x_min, x_max = all_tsne1.min(), all_tsne1.max()
350 |
y_min, y_max = all_tsne2.min(), all_tsne2.max()
351 |
x_min, x_max = [11, 16] # manually set range for cleaner plotting
352 |
y_min, y_max = [10, 22]
353 |
354 |
# Determine tick positions
355 |
x_ticks = np.arange(x_min, x_max + 1, 1)
356 |
y_ticks = np.arange(y_min, y_max + 1, 1)
357 |
358 |
# Flatten the axes array for easier iteration
359 |
axes = axes.flatten()
360 |
361 |
for i, ax in enumerate(axes):
362 |
# Extract the gene names from the current row
363 |
fgene_name = merge.loc[i, 'gene']
364 |
hgene = merge.loc[i, 'head']
365 |
tgene = merge.loc[i, 'tail']
366 |
367 |
# Filter tsne_embeddings for the relevant entries
368 |
tsne_data = seqs_with_embeddings[seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])]
369 |
370 |
# Plot each type
371 |
for emb_type in tsne_data['Type'].unique():
372 |
subset = tsne_data[tsne_data['Type'] == emb_type]
373 |
ax.scatter(subset[f'{dimred_type}1'], subset[f'{dimred_type}2'], label=label_map[emb_type], color=colors[emb_type], marker=markers[emb_type], s=120, zorder=3)
374 |
375 |
376 |
label_transform = {
377 |
'tsne': 't-SNE',
378 |
'umap': 'UMAP'
379 |
380 |
ax.set_xlabel(f'{label_transform[dimred_type]} 1',fontsize=44)
381 |
ax.set_ylabel(f'{label_transform[dimred_type]} 2',fontsize=44)
382 |
ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)
383 |
384 |
# Set the same limits and ticks for all axes
385 |
ax.set_xlim(x_min, x_max)
386 |
ax.set_ylim(y_min, y_max)
387 |
ax.set_xticks(x_ticks)#\\, labelsize=24)
388 |
ax.set_yticks(y_ticks)#, labelsize=24)
389 |
390 |
# Rotate x-axis labels
391 |
ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')
392 |
393 |
ax.tick_params(axis='x', labelsize=16)
394 |
ax.tick_params(axis='y', labelsize=16)
395 |
396 |
for label in ax.get_xticklabels():
397 |
398 |
for label in ax.get_yticklabels():
399 |
400 |
401 |
# Set font size for the legend if needed
402 |
if i == 0:
403 |
legend = ax.legend(fontsize=20, markerscale=2, loc='best')
404 |
for text in legend.get_texts():
405 |
406 |
407 |
# Adjust layout to prevent overlap
408 |
409 |
410 |
# Show the plot
411 |
412 |
413 |
# Save the figure
414 |
plt.savefig(f'{savedir}/{dimred_type}_favorites_visualization.png', dpi=300)
415 |
416 |
# Save the data
417 |
seq_to_id_dict = pd.read_csv("../../data/fuson_db.csv")
418 |
seq_to_id_dict = dict(zip(seq_to_id_dict['aa_seq'],seq_to_id_dict['seq_id']))
419 |
seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
420 |
421 |
422 |
def fusion_v_parts_favorites(dimred_types, output_dir):
423 |
424 |
Makes the embeddings, THEN calls the plot. only on the four favorites
425 |
426 |
427 |
# path to the pkl file with FOdb embeddings
428 |
429 |
all_embedding_paths = embed_dataset_for_benchmark(
430 |
431 |
input_data_path='data/top_genes.csv', input_fname=input_fname,
432 |
average=True, seq_col='sequence',
433 |
434 |
435 |
436 |
437 |
438 |
# For each of the models we are benchmarking, load embeddings and make plots
439 |
log_update("\nEmbedding sequences")
440 |
# loop through the embedding paths and train each one
441 |
for embedding_path, details in all_embedding_paths.items():
442 |
log_update(f"\tBenchmarking embeddings at: {embedding_path}")
443 |
444 |
with open(embedding_path, "rb") as f:
445 |
embeddings = pickle.load(f)
446 |
447 |
raise Exception(f"Cannot read embeddings from {embedding_path}")
448 |
449 |
# combine the embeddings and splits into one dataframe
450 |
seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
451 |
seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'}) # the column that was called FusOn-pLM is now called embedding
452 |
453 |
# get UMAP transform of the embeddings
454 |
for dimred_type in dimred_types:
455 |
dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(),dimred_type=dimred_type)
456 |
457 |
# turn the result into a dataframe, and add it to seqs_with_embeddings
458 |
data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
459 |
# save the umap data!
460 |
model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])
461 |
462 |
seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data
463 |
464 |
# make subdirectory
465 |
intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
466 |
cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
467 |
468 |
469 |
make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir = cur_output_dir, dimred_type=dimred_type)
470 |
471 |
def main():
472 |
# make directory to save results
473 |
474 |
output_dir = f'results/{get_local_time()}'
475 |
476 |
477 |
dimred_types = []
478 |
if config.PLOT_UMAP:
479 |
480 |
481 |
482 |
if config.PLOT_TSNE:
483 |
484 |
485 |
486 |
487 |
with open_logfile(f'{output_dir}/embedding_exploration_log.txt'):
488 |
489 |
# make the distinct embeddings plot
490 |
fusion_v_parts_favorites(dimred_types, output_dir)
491 |
492 |
tf_and_kinase_fusions_plot(dimred_types, output_dir)
493 |
494 |
495 |
if __name__ == "__main__":
496 |
1 |
2 |
oid sha256:28c0b51f513da01df3dee3c4e71aa0c583bd57d9878137bdac9e7ebc704694e4
3 |
size 17383
![]() |
1 |
2 |
oid sha256:b26a5a6c2f8f54225fd46f01dab52813532438732624561af8e2e4ad005e5dc7
3 |
size 570073
![]() |
81 |
82 |
nohup python > discover.out 2> discover.err &
83 |
84 |
- All **results** are stored in `
85 |
86 |
Below are the FusOn-pLM paper results in `results/final`:
87 |
81 |
82 |
nohup python > discover.out 2> discover.err &
83 |
84 |
- All **results** are stored in `mutation_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
85 |
86 |
Below are the FusOn-pLM paper results in `results/final`:
87 |
5 |
import pickle
6 |
import os
7 |
8 |
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
9 |
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
10 |
import fuson_plm.benchmarking.puncta.config as config
11 |
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
5 |
import pickle
6 |
import os
7 |
8 |
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
9 |
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
10 |
import fuson_plm.benchmarking.puncta.config as config
11 |
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
1 |
from sklearn.model_selection import train_test_split, StratifiedKFold
2 |
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, roc_auc_score, average_precision_score
3 |
from fuson_plm.utils.logging import log_update
4 |
import time
5 |
import xgboost as xgb
6 |
import numpy as np
7 |
import pandas as pd
8 |
9 |
def train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist"):
    """
    Fit an XGBoost classifier on the training data.

    Parameters:
    - X_train: Training features, passed to XGBClassifier.fit.
    - y_train: Training labels, passed to XGBClassifier.fit.
    - n_estimators (int): Forwarded to xgb.XGBClassifier.
    - tree_method (str): Forwarded to xgb.XGBClassifier.

    Returns:
    - The fitted xgb.XGBClassifier.
    """
    clf = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
    # The classifier must be fitted before it is returned; evaluate_predictor
    # calls predict / predict_proba on it.
    clf.fit(X_train, y_train)
    return clf
13 |
14 |
def _binary_classification_stats(y_true, y_pred_labels, y_pred_probs):
    """Build the one-row metrics table for one set of predicted labels and probabilities."""
    return pd.DataFrame(data={
        'Accuracy': [accuracy_score(y_true, y_pred_labels)],
        'Precision': [precision_score(y_true, y_pred_labels)],
        'Recall': [recall_score(y_true, y_pred_labels)],
        'F1 Score': [f1_score(y_true, y_pred_labels)],
        # "AUROC"/"AUPRC" are computed from probabilities, the "Label" variants from hard labels
        'AUROC': [roc_auc_score(y_true, y_pred_probs)],
        'AUROC Label': [roc_auc_score(y_true, y_pred_labels)],
        'AUPRC': [average_precision_score(y_true, y_pred_probs)],
        'AUPRC Label': [average_precision_score(y_true, y_pred_labels)],
    })


def evaluate_predictor(clf, X_test, y_test, class1_thresh=None):
    """
    Score a fitted classifier on a test set.

    Parameters:
    - clf: Fitted classifier exposing predict and predict_proba.
    - X_test: Test features.
    - y_test: True test labels.
    - class1_thresh (float, optional): If given, class-1 probabilities >= this value are
      labelled 1 (otherwise 0) and a second metrics table is computed for those labels.

    Returns:
    - (automatic_stats_df, custom_stats_df): one-row DataFrames with columns Accuracy,
      Precision, Recall, F1 Score, AUROC, AUROC Label, AUPRC, AUPRC Label.
      custom_stats_df is None when class1_thresh is None.
    """
    # Predicting the labels on test set
    y_pred_test = clf.predict(X_test)  # labels with automatic thresholds
    y_pred_prob_test = clf.predict_proba(X_test)[:, 1]  # probability of class 1

    # Calculating metrics - automatic
    automatic_stats_df = _binary_classification_stats(y_test, y_pred_test, y_pred_prob_test)

    # Calculating metrics - custom threshold (note that probability ones won't change)
    custom_stats_df = None
    if class1_thresh is not None:
        y_pred_customthresh_test = np.where(np.array(y_pred_prob_test) >= class1_thresh, 1, 0)
        custom_stats_df = _binary_classification_stats(
            y_test, y_pred_customthresh_test, y_pred_prob_test
        )

    return automatic_stats_df, custom_stats_df