svincoff committed on
Commit c43fbc6 · 1 Parent(s): 3efa812

dependencies and embedding_exploration benchmark
fuson_plm/README.md ADDED
@@ -0,0 +1,588 @@
# Dependencies

Here we provide the package versions needed to run FusOn-pLM code. Docker containers were used for the project. We provide a pip list of what is inside each Docker container, as well as the images used for our containers.

## pip installs

The following dependencies were used for all training and benchmarking except for the `puncta` benchmarks.
Note that after cloning the repository, you will need to run `pip install -e .` outside the `fuson_plm` directory to install the `fuson_plm` package.

Package                   Version              Editable project location
------------------------- -------------------- -------------------------
absl-py                   1.4.0
aiohttp                   3.8.4
aiosignal                 1.3.1
apex                      0.1
argon2-cffi               21.3.0
argon2-cffi-bindings      21.2.0
asttokens                 2.2.1
astunparse                1.6.3
async-timeout             4.0.2
attrs                     23.1.0
audioread                 3.0.0
backcall                  0.2.0
beautifulsoup4            4.12.2
bio                       1.7.1
biopython                 1.84
biothings-client          0.3.1
bleach                    6.0.0
blis                      0.7.10
cachetools                5.3.1
catalogue                 2.0.9
certifi                   2023.7.22
cffi                      1.15.1
charset-normalizer        3.2.0
click                     8.1.5
cloudpickle               2.2.1
cmake                     3.27.1
comm                      0.1.4
confection                0.1.1
contourpy                 1.1.0
cubinlinker               0.3.0+2.g7c3675e
cuda-python               12.1.0rc5+1.g994d8d0
cudf                      23.6.0
cugraph                   23.6.0
cugraph-dgl               23.6.0
cugraph-service-client    23.6.0
cugraph-service-server    23.6.0
cuml                      23.6.0
cupy-cuda12x              12.1.0
cycler                    0.11.0
cymem                     2.0.7
Cython                    3.0.0
dask                      2023.3.2
dask-cuda                 23.6.0
dask-cudf                 23.6.0
debugpy                   1.6.7
decorator                 5.1.1
defusedxml                0.7.1
distributed               2023.3.2.1
dm-tree                   0.1.8
docker-pycreds            0.4.0
einops                    0.6.1
exceptiongroup            1.1.2
execnet                   2.0.2
executing                 1.2.0
expecttest                0.1.3
fair-esm                  2.0.0
fastjsonschema            2.18.0
fastrlock                 0.8.1
filelock                  3.12.2
flash-attn                2.0.4
fonttools                 4.42.0
frozenlist                1.4.0
fsspec                    2023.6.0
fuson-plm                 1.0                  /workspace/FusOn-pLM
gast                      0.5.4
gdown                     5.2.0
gitdb                     4.0.11
GitPython                 3.1.43
google-auth               2.22.0
google-auth-oauthlib      0.4.6
gprofiler-official        1.0.0
graphsurgeon              0.4.6
grpcio                    1.56.2
huggingface-hub           0.25.2
hypothesis                5.35.1
idna                      3.4
importlib-metadata        6.8.0
iniconfig                 2.0.0
intel-openmp              2021.4.0
ipykernel                 6.25.0
ipython                   8.14.0
ipython-genutils          0.2.0
jedi                      0.19.0
Jinja2                    3.1.2
joblib                    1.3.1
json5                     0.9.14
jsonschema                4.18.6
jsonschema-specifications 2023.7.1
jupyter_client            8.3.0
jupyter_core              5.3.1
jupyter-tensorboard       0.2.0
jupyterlab                2.3.2
jupyterlab-pygments       0.2.2
jupyterlab-server         1.2.0
jupytext                  1.15.0
kiwisolver                1.4.4
langcodes                 3.3.0
librosa                   0.9.2
lightning-utilities       0.11.8
llvmlite                  0.40.1
locket                    1.0.0
Markdown                  3.4.4
markdown-it-py            3.0.0
MarkupSafe                2.1.3
matplotlib                3.7.2
matplotlib-inline         0.1.6
mdit-py-plugins           0.4.0
mdurl                     0.1.2
mistune                   3.0.1
mkl                       2021.1.1
mkl-devel                 2021.1.1
mkl-include               2021.1.1
mock                      5.1.0
mpmath                    1.3.0
msgpack                   1.0.5
multidict                 6.0.4
murmurhash                1.0.9
mygene                    3.2.2
nbclient                  0.8.0
nbconvert                 7.7.3
nbformat                  5.9.2
nest-asyncio              1.5.7
networkx                  2.6.3
ninja                     1.11.1
notebook                  6.4.10
numba                     0.57.1+1.gc785c8f1f
numpy                     1.22.2
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-dali-cuda120       1.28.0
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
nvidia-pyindex            1.0.9
nvtx                      0.2.5
oauthlib                  3.2.2
onnx                      1.14.0
opencv                    4.7.0
packaging                 23.1
pandas                    1.5.2
pandocfilters             1.5.0
parso                     0.8.3
partd                     1.4.0
pathy                     0.10.2
pexpect                   4.8.0
pickleshare               0.7.5
Pillow                    9.2.0
pip                       23.2.1
platformdirs              3.10.0
pluggy                    1.2.0
ply                       3.11
polygraphy                0.47.1
pooch                     1.7.0
preshed                   3.0.8
prettytable               3.8.0
prometheus-client         0.17.1
prompt-toolkit            3.0.39
protobuf                  4.21.12
psutil                    5.9.4
ptxcompiler               0.8.1+1.g4a94326
ptyprocess                0.7.0
pure-eval                 0.2.2
py3Dmol                   2.4.0
pyarrow                   11.0.0
pyasn1                    0.5.0
pyasn1-modules            0.3.0
pybind11                  2.11.1
pycocotools               2.0+nv0.7.3
pycparser                 2.21
pydantic                  1.10.12
Pygments                  2.16.1
pylibcugraph              23.6.0
pylibcugraphops           23.6.0
pylibraft                 23.6.0
pynndescent               0.5.13
pynvml                    11.4.1
pyparsing                 3.0.9
PySocks                   1.7.1
pytest                    7.4.0
pytest-flakefinder        1.1.0
pytest-rerunfailures      12.0
pytest-shard              0.1.2
pytest-xdist              3.3.1
python-dateutil           2.8.2
python-hostlist           1.23.0
pytorch-lightning         2.4.0
pytorch-quantization      2.1.2
pytz                      2023.3
PyYAML                    6.0.1
pyzmq                     25.1.0
raft-dask                 23.6.0
referencing               0.30.2
regex                     2023.6.3
requests                  2.31.0
requests-oauthlib         1.3.1
resampy                   0.4.2
rmm                       23.6.0
rpds-py                   0.9.2
rsa                       4.9
safetensors               0.4.5
scikit-learn              1.2.0
scipy                     1.11.1
seaborn                   0.13.2
Send2Trash                1.8.2
sentencepiece             0.2.0
sentry-sdk                2.16.0
setproctitle              1.3.3
setuptools                68.0.0
six                       1.16.0
smart-open                6.3.0
smmap                     5.0.1
sortedcontainers          2.4.0
soundfile                 0.12.1
soupsieve                 2.4.1
spacy                     3.6.0
spacy-legacy              3.0.12
spacy-loggers             1.0.4
sphinx-glpi-theme         0.3
srsly                     2.4.7
stack-data                0.6.2
sympy                     1.13.1
tabulate                  0.9.0
tbb                       2021.10.0
tblib                     2.0.0
tensorboard               2.9.0
tensorboard-data-server   0.6.1
tensorboard-plugin-wit    1.8.1
tensorrt                  8.6.1
terminado                 0.17.1
thinc                     8.1.10
threadpoolctl             3.2.0
thriftpy2                 0.4.16
tinycss2                  1.2.1
tokenizers                0.20.1
toml                      0.10.2
tomli                     2.0.1
toolz                     0.12.0
torch                     2.5.0
torch-tensorrt            2.0.0.dev0
torchdata                 0.7.0a0
torchmetrics              1.5.0
torchtext                 0.16.0a0
torchvision               0.16.0a0
tornado                   6.3.2
tqdm                      4.65.0
traitlets                 5.9.0
transformer-engine        0.11.0+3f01b4f
transformers              4.45.2
treelite                  3.2.0
treelite-runtime          3.2.0
triton                    3.1.0
typer                     0.9.0
types-dataclasses         0.6.6
typing_extensions         4.12.2
ucx-py                    0.32.0
uff                       0.6.9
umap-learn                0.5.6
urllib3                   1.26.16
wandb                     0.18.3
wasabi                    1.1.2
wcwidth                   0.2.6
webencodings              0.5.1
Werkzeug                  2.3.6
wheel                     0.41.1
xdoctest                  1.0.2
xgboost                   1.7.5
yarl                      1.9.2
zict                      3.0.0
zipp                      3.16.2

The following packages and versions were used for the `puncta` benchmarks. A different environment was required to run ProtT5.

Package                   Version                    Editable project location
------------------------- -------------------------- -------------------------
absl-py                   2.1.0
aiohttp                   3.9.3
aiosignal                 1.3.1
annotated-types           0.6.0
anyio                     4.8.0
apex                      0.1
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
asttokens                 2.4.1
astunparse                1.6.3
async-timeout             4.0.3
attrs                     23.2.0
audioread                 3.0.1
beautifulsoup4            4.12.3
bio                       1.7.1
biopython                 1.85
biothings_client          0.4.1
bleach                    6.1.0
blis                      0.7.11
cachetools                5.3.3
catalogue                 2.0.10
certifi                   2024.2.2
cffi                      1.16.0
charset-normalizer        3.3.2
click                     8.1.7
cloudpathlib              0.16.0
cloudpickle               3.0.0
cmake                     3.29.0.1
comm                      0.2.2
confection                0.1.4
contourpy                 1.2.1
cuda-python               12.4.0rc7+3.ge75c8a9.dirty
cudf                      24.2.0
cudnn                     1.1.2
cugraph                   24.2.0
cugraph-dgl               24.2.0
cugraph-service-client    24.2.0
cugraph-service-server    24.2.0
cuml                      24.2.0
cupy-cuda12x              13.0.0
cycler                    0.12.1
cymem                     2.0.8
Cython                    3.0.10
dask                      2024.1.1
dask-cuda                 24.2.0
dask-cudf                 24.2.0
debugpy                   1.8.1
decorator                 5.1.1
defusedxml                0.7.1
distributed               2024.1.1
dm-tree                   0.1.8
docker-pycreds            0.4.0
einops                    0.7.0
exceptiongroup            1.2.0
execnet                   2.0.2
executing                 2.0.1
expecttest                0.1.3
fair-esm                  2.0.0
fastjsonschema            2.19.1
fastrlock                 0.8.2
filelock                  3.13.3
flash-attn                2.4.2
fonttools                 4.51.0
frozenlist                1.4.1
fsspec                    2024.2.0
fuson-plm                 1.0                        /workspace/FusOn-pLM
gast                      0.5.4
gdown                     5.2.0
gitdb                     4.0.12
GitPython                 3.1.44
google-auth               2.29.0
google-auth-oauthlib      0.4.6
gprofiler-official        1.0.0
graphsurgeon              0.4.6
grpcio                    1.62.1
h11                       0.14.0
httpcore                  1.0.7
httpx                     0.28.1
huggingface-hub           0.27.1
hypothesis                5.35.1
idna                      3.6
igraph                    0.11.4
importlib_metadata        7.0.2
iniconfig                 2.0.0
intel-openmp              2021.4.0
ipykernel                 6.29.4
ipython                   8.21.0
ipython-genutils          0.2.0
jedi                      0.19.1
Jinja2                    3.1.3
joblib                    1.3.2
json5                     0.9.24
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
jupyter_client            8.6.1
jupyter_core              5.7.2
jupyter-tensorboard       0.2.0
jupyterlab                2.3.2
jupyterlab_pygments       0.3.0
jupyterlab-server         1.2.0
jupytext                  1.16.1
kiwisolver                1.4.5
langcodes                 3.3.0
lark                      1.1.9
lazy_loader               0.4
librosa                   0.10.1
lightning-thunder         0.1.0
lightning-utilities       0.11.2
llvmlite                  0.42.0
locket                    1.0.0
looseversion              1.3.0
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                2.1.5
matplotlib                3.8.4
matplotlib-inline         0.1.6
mdit-py-plugins           0.4.0
mdurl                     0.1.2
mistune                   3.0.2
mkl                       2021.1.1
mkl-devel                 2021.1.1
mkl-include               2021.1.1
mock                      5.1.0
mpmath                    1.3.0
msgpack                   1.0.8
multidict                 6.0.5
murmurhash                1.0.10
mygene                    3.2.2
nbclient                  0.10.0
nbconvert                 7.16.3
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  2.6.3
ninja                     1.11.1.1
notebook                  6.4.10
numba                     0.59.0+1.g20ae2b56c
numpy                     1.24.4
nvfuser                   0.1.6a0+a684e2a
nvidia-dali-cuda120       1.36.0
nvidia-nvimgcodec-cu12    0.2.0.7
nvidia-pyindex            1.0.9
nvtx                      0.2.5
oauthlib                  3.2.2
onnx                      1.16.0
opencv                    4.7.0
opt-einsum                3.3.0
optree                    0.11.0
packaging                 23.2
pandas                    1.5.3
pandocfilters             1.5.1
parso                     0.8.4
partd                     1.4.1
pexpect                   4.9.0
pillow                    10.2.0
pip                       24.0
platformdirs              4.2.0
pluggy                    1.4.0
ply                       3.11
polygraphy                0.49.8
pooch                     1.8.1
preshed                   3.0.9
prettytable               3.10.0
prometheus_client         0.20.0
prompt-toolkit            3.0.43
protobuf                  4.24.4
psutil                    5.9.4
ptyprocess                0.7.0
pure-eval                 0.2.2
py3Dmol                   2.4.2
pyarrow                   14.0.1
pyasn1                    0.6.0
pyasn1_modules            0.4.0
pybind11                  2.12.0
pybind11_global           2.12.0
pycocotools               2.0+nv0.8.0
pycparser                 2.22
pydantic                  2.6.4
pydantic_core             2.16.3
Pygments                  2.17.2
pylibcugraph              24.2.0
pylibcugraphops           24.2.0
pylibraft                 24.2.0
pynndescent               0.5.13
pynvjitlink               0.1.13
pynvml                    11.4.1
pyparsing                 3.1.2
PySocks                   1.7.1
pytest                    8.1.1
pytest-flakefinder        1.1.0
pytest-rerunfailures      14.0
pytest-shard              0.1.2
pytest-xdist              3.5.0
python-dateutil           2.9.0.post0
python-hostlist           1.23.0
pytorch-lightning         2.5.0.post0
pytorch-quantization      2.1.2
pytorch-triton            3.0.0+a9bc1a364
pytz                      2024.1
PyYAML                    6.0.1
pyzmq                     25.1.2
raft-dask                 24.2.0
rapids-dask-dependency    24.2.0a0
referencing               0.34.0
regex                     2023.12.25
requests                  2.31.0
requests-oauthlib         2.0.0
rich                      13.7.1
rmm                       24.2.0
rpds-py                   0.18.0
rsa                       4.9
safetensors               0.5.2
scikit-learn              1.2.0
scipy                     1.12.0
seaborn                   0.13.2
Send2Trash                1.8.2
sentencepiece             0.2.0
sentry-sdk                2.20.0
setproctitle              1.3.4
setuptools                68.2.2
six                       1.16.0
smart-open                6.4.0
smmap                     5.0.2
sniffio                   1.3.1
sortedcontainers          2.4.0
soundfile                 0.12.1
soupsieve                 2.5
soxr                      0.3.7
spacy                     3.7.4
spacy-legacy              3.0.12
spacy-loggers             1.0.5
sphinx_glpi_theme         0.6
srsly                     2.4.8
stack-data                0.6.3
sympy                     1.12
tabulate                  0.9.0
tbb                       2021.12.0
tblib                     3.0.0
tensorboard               2.9.0
tensorboard-data-server   0.6.1
tensorboard-plugin-wit    1.8.1
tensorrt                  8.6.3
terminado                 0.18.1
texttable                 1.7.0
thinc                     8.2.3
threadpoolctl             3.3.0
thriftpy2                 0.4.17
tinycss2                  1.2.1
tokenizers                0.21.0
toml                      0.10.2
tomli                     2.0.1
toolz                     0.12.1
torch                     2.3.0a0+6ddf5cf85e.nv24.4
torch-tensorrt            2.3.0a0
torchdata                 0.7.1a0
torchmetrics              1.6.1
torchtext                 0.17.0a0
torchvision               0.18.0a0
tornado                   6.4
tqdm                      4.66.2
traitlets                 5.9.0
transformer-engine        1.5.0+6a9edc3
transformers              4.48.0
treelite                  4.0.0
typer                     0.9.4
types-dataclasses         0.6.6
typing_extensions         4.10.0
ucx-py                    0.36.0
uff                       0.6.9
umap-learn                0.5.7
urllib3                   1.26.18
wandb                     0.19.4
wasabi                    1.1.2
wcwidth                   0.2.13
weasel                    0.3.4
webencodings              0.5.1
Werkzeug                  3.0.2
wheel                     0.43.0
xdoctest                  1.0.2
xgboost                   1.7.5
yarl                      1.9.4
zict                      3.0.0
zipp                      3.17.0

## Docker

The following image was used for Container 1 (all code except the puncta benchmark):

```
nvcr.io/nvidia/pytorch:23.08-py3
```

The following image was used for Container 2 (puncta benchmark):

```
nvcr.io/nvidia/pytorch:24.04-py3
```
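To sanity-check a local environment against the pins above, a small script can compare installed versions using the standard library. This is a hedged sketch: the `PINS` subset shown is illustrative, and in practice you would populate it from the full list for the container you are reproducing.

```python
from importlib.metadata import version, PackageNotFoundError

# Illustrative subset of the Container 1 pin list above
PINS = {"numpy": "1.22.2", "transformers": "4.45.2"}

def check_pins(pins):
    """Return {name: (pinned, installed, matches)} for each pinned package."""
    report = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package not present in this environment
        report[name] = (pinned, installed, installed == pinned)
    return report

for name, (pinned, installed, ok) in check_pins(PINS).items():
    print(f"{name}: pinned {pinned}, installed {installed}, {'OK' if ok else 'MISMATCH'}")
```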
fuson_plm/benchmarking/README.md ADDED
@@ -0,0 +1,11 @@
# Benchmarking

This outer directory for the FusOn-pLM benchmarks holds shared utility functions in `.py` files.

### embed.py

This file contains functions used to make and organize FusOn-pLM and ESM embeddings of the benchmarking datasets. Its functions are used in all benchmarks.

### xgboost_predictor.py

This file contains functions used to train XGBoost predictors, which are utilized in the `puncta` benchmark.
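The embedding-to-predictor workflow these utilities support can be sketched as follows. This is a hedged illustration rather than the repo's actual pipeline: the data is synthetic, and scikit-learn's `GradientBoostingClassifier` stands in for the XGBoost model trained by `xgboost_predictor.py`.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 1280))   # stand-in for 1280-dim mean-pooled embeddings
y = rng.integers(0, 2, size=200)   # stand-in binary labels (e.g., puncta formation)

# Train a gradient-boosted classifier on the embeddings, evaluate on a held-out split
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)
clf = GradientBoostingClassifier(random_state=0).fit(X_tr, y_tr)
print(f"held-out accuracy: {clf.score(X_te, y_te):.2f}")
```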
fuson_plm/benchmarking/__init__.py ADDED
File without changes
fuson_plm/benchmarking/embed.py ADDED
@@ -0,0 +1,296 @@
# Python file for making embeddings from a FusOn-pLM model for any dataset
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.utils.data_cleaning import find_invalid_chars
from fuson_plm.utils.constants import VALID_AAS
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import pickle
import os
import pandas as pd
import numpy as np

def validate_sequence_col(df, seq_col):
    # If the column isn't there, error
    if seq_col not in list(df.columns):
        raise Exception("Error: provided sequence column does not exist in the input dataframe")

    # If the column contains invalid characters, error
    df['invalid_chars'] = df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*df['invalid_chars'])
    df = df.drop(columns=['invalid_chars'])
    if len(all_invalid_chars) > 0:
        raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")

    # Make sure there are no duplicates
    sequences = df[seq_col]
    if len(set(sequences)) < len(sequences): log_update("\tWARNING: input data has duplicate sequences")

def load_fuson_model(ckpt_path):
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    model = AutoModel.from_pretrained(ckpt_path)          # initialize model
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)  # initialize tokenizer

    # Model to device and in eval mode
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device

def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4:] != '.pkl': savepath += '.pkl'

    if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2  # add 2 for BOS, EOS

    # Initialize an empty dict to store the embeddings
    embedding_dict = {}
    # Iterate through the seqs
    for i in range(len(sequences)):
        sequence = sequences[i]
        # Get the embeddings
        with torch.no_grad():
            # Tokenize the input sequence
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            # The embeddings are in the last_hidden_state tensor
            embedding = outputs.last_hidden_state
            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

        # Convert embeddings to numpy array (if needed)
        embedding = embedding.cpu().numpy()

        # Average (if necessary)
        if average:
            embedding = embedding.mean(0)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end):
            with open(savepath, 'ab+') as f:
                d = {sequence: embedding}
                pickle.dump(d, f)

        # Print update (if necessary)
        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path=None, average=True, overwrite=True, print_updates=False, max_length=2000):
    # Make sure we aren't overwriting pre-existing embeddings
    if os.path.exists(path_to_output):
        if overwrite:
            log_update(f"WARNING: these embeddings may already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: these embeddings may already exist at {path_to_output}. Skipping.")
            return None

    dataset = pd.read_csv(path_to_file)
    # Make sure the sequence column is valid
    validate_sequence_col(dataset, seq_col)

    sequences = dataset[seq_col].unique().tolist()  # ensure all entries are unique

    ### If FusOn-pLM: make fusion embeddings
    if model_type == 'fuson_plm':
        if not(os.path.exists(fuson_ckpt_path)): raise Exception("FusOn-pLM ckpt path does not exist")

        # Load model
        try:
            model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
        except Exception as e:
            raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}") from e

        # Generate embeddings
        try:
            get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
                                 print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                 max_length=max_length)
        except Exception as e:
            raise Exception("Could not generate FusOn-pLM embeddings") from e

    if model_type == 'esm2_t33_650M_UR50D':
        # Load model
        try:
            model, tokenizer, device = load_esm2_type(model_type)
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
        # Generate embeddings
        try:
            get_esm_embeddings(model, tokenizer, sequences, device, average=average,
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e

    if model_type == "prot_t5_xl_half_uniref50_enc":
        # Load model
        try:
            model, tokenizer, device = load_prott5()
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
        # Generate embeddings
        try:
            get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
                                  print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                  max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e


def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False, max_length=None):
    # Make a directory for embeddings inside the benchmarking dataset if one doesn't already exist
    os.makedirs('embeddings', exist_ok=True)

    # Extract input file name from configs
    emb_type_tag = 'average' if average else '2D'

    all_embedding_paths = dict()  # dictionary organized where embedding path points to model, epoch

    # Make the embedding files. Put them in an embedding directory
    if benchmark_fusonplm:
        os.makedirs('embeddings/fuson_plm', exist_ok=True)

        log_update(f"\nMaking FusOn-pLM embeddings")
        # Make subdirs for all the models
        if type(fuson_ckpts) == dict:
            for model_name, epoch_list in fuson_ckpts.items():
                os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
                for epoch in epoch_list:
                    # Assemble ckpt path and throw error if it doesn't exist
                    fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
                    if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

                    # Make output directory and output embedding path
                    embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
                    embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
                    os.makedirs(embedding_output_dir, exist_ok=True)

                    # Make dictionary item
                    model_type = 'fuson_plm'
                    all_embedding_paths[embedding_output_path] = {
                        'model_type': model_type,
                        'model': model_name,
                        'epoch': epoch
                    }

                    # Create embeddings (or skip if they're already made)
                    log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
                    embed_dataset(input_data_path, embedding_output_path,
                                  seq_col=seq_col, model_type=model_type,
                                  fuson_ckpt_path=fuson_ckpt_path, average=average,
                                  overwrite=overwrite, print_updates=True,
                                  max_length=max_length)
        elif fuson_ckpts == "FusOn-pLM":
            model_name = "best"
            os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)

            # Assemble ckpt path and throw error if it doesn't exist
            fuson_ckpt_path = "../../.."  # go back to the FusOn-pLM directory to find the best ckpt
            if not(os.path.exists(fuson_ckpt_path)): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

            # Make output directory and output embedding path
            embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
            embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
            os.makedirs(embedding_output_dir, exist_ok=True)

            # Make dictionary item
            model_type = 'fuson_plm'
            all_embedding_paths[embedding_output_path] = {
                'model_type': model_type,
                'model': model_name,
                'epoch': None
            }

            # Create embeddings (or skip if they're already made)
            log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
            embed_dataset(input_data_path, embedding_output_path,
                          seq_col=seq_col, model_type=model_type,
                          fuson_ckpt_path=fuson_ckpt_path, average=average,
                          overwrite=overwrite, print_updates=True,
                          max_length=max_length)
        else:
            raise Exception(f"Error. fuson_ckpts should be a dict or str")

    # Make the embedding files. Put them in an embedding directory
    if benchmark_esm:
        os.makedirs('embeddings/esm2_t33_650M_UR50D', exist_ok=True)

        # Make output path
        embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'

        # Make dictionary item
        model_type = 'esm2_t33_650M_UR50D'
        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan
        }

        log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type=model_type,
                      fuson_ckpt_path=None, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    if benchmark_prott5:
        os.makedirs('embeddings/prot_t5_xl_half_uniref50_enc', exist_ok=True)

        # Make output path
        embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'

        # Make dictionary item
        model_type = 'prot_t5_xl_half_uniref50_enc'
        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan
        }

        log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type=model_type,
                      fuson_ckpt_path=None, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    if benchmark_fo_puncta_ml:
        embedding_output_path = f'FOdb_physicochemical_embeddings.pkl'
        # Make dictionary item
        all_embedding_paths[embedding_output_path] = {
            'model_type': 'fo_puncta_ml',
            'model': 'fo_puncta_ml',
            'epoch': np.nan
        }

    return all_embedding_paths
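`get_fuson_embeddings` above appends one pickled `{sequence: embedding}` record per iteration (so a crash loses at most one record) and then consolidates the file into a single dictionary via `redump_pickle_dictionary`. A self-contained sketch of that append-then-merge pattern, using toy data:

```python
import os
import pickle
import tempfile

records = {"MKT": [0.1, 0.2], "GAV": [0.3, 0.4]}  # toy {sequence: embedding} data

# Append one pickled record at a time (crash-tolerant incremental save)
path = os.path.join(tempfile.mkdtemp(), "emb.pkl")
for seq, emb in records.items():
    with open(path, "ab") as f:
        pickle.dump({seq: emb}, f)

# Consolidate: read every appended record, merge, and re-dump as one dict
merged = {}
with open(path, "rb") as f:
    while True:
        try:
            merged.update(pickle.load(f))
        except EOFError:  # no more appended records
            break
with open(path, "wb") as f:
    pickle.dump(merged, f)

with open(path, "rb") as f:
    print(pickle.load(f) == records)  # True
```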
fuson_plm/benchmarking/embedding_exploration/README.md ADDED
@@ -0,0 +1,58 @@
1
+ # Embedding exploration
2
+
3
+ This folder contains all the data and code needed to run embedding exploration (Fig. S3).
4
+
5
+ ### Data download
6
+ To help select TF (transcription factor) and Kinase-containing fusions for investigation (Fig. S3a), Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8) was downloaded as a reference of transcription factors and kinases.
7
+
8
+ ```
9
+ benchmarking/
10
+ └── embedding_exploration/
11
+ └── data/
12
+ ├── salokas_2020_tableS3.csv
13
+ ├── tf_and_kinase_fusions.csv
14
+ ├── top_genes.csv
15
+ ```
16
+
17
+ - **`data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8)
18
+ - **`data/tf_and_kinase_fusions.csv`**: set of TF::TF and Kinase::Kinase fusion oncoproteins from FusOn-DB database. Curated in `plot.py`
19
+ - **`data/top_genes.csv`**: fusion oncoproteins (and their head and tail components) visualized in Fig. S3b. Sequences for head and tail components were pulled from the best-aligned sequences in `fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl`
20
+
21
+ ### Plotting
22
+
23
+ Run `plot.py` to regenerate plots in Figure S3:
24
+
25
+ ```
26
+ # Dictionary: key = run name, values = epochs. (use this option if you've trained your own model)
27
+ # # Or "FusOn-pLM" to use official model
28
+ FUSON_PLM_CKPT= "FusOn-pLM"
29
+
30
+ # Type of dim reduction
31
+ PLOT_UMAP = True
32
+ PLOT_TSNE = False
33
+
34
+ # Overwriting configs
35
+ PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
36
+ ```
37
+
38
+ To run, use:
39
+ ```
40
+ nohup python plot.py > plot.out 2> plot.err &
41
+ ```
42
+ - All **results** are stored in `embedding_exploration/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
43
+
44
+ Below are the FusOn-pLM paper results in `results/final/umap_plots/fuson_plm/best/`:
45
+
46
+ ```
47
+ benchmarking/
+ └── embedding_exploration/
+     └── results/final/umap_plots/fuson_plm/best/
+         ├── favorites/
+         │   ├── umap_favorites_source_data.csv
+         │   └── umap_favorites_visualization.png
+         └── tf_and_kinase/
+             ├── umap_tf_and_kinase_fusions_source_data.csv
+             └── umap_tf_and_kinase_fusions_visualization.png
55
+ ```
56
+
57
+ - **`favorites/umap_favorites_visualization.png`**: Fig. S3b; the plotted data are stored in `favorites/umap_favorites_source_data.csv`
58
+ - **`tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png`**: Fig. S3a; the plotted data are stored in `tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv`.
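Because each `*_source_data.csv` stores the exact plotted coordinates, the figures can be inspected without recomputing embeddings. A minimal sketch, assuming the column names used in `plot.py` (`umap1`, `umap2`, `fusion_type`); the example DataFrame below is synthetic:

```python
import pandas as pd

def summarize_umap_source_data(df):
    """Count points per fusion_type and report the UMAP coordinate ranges."""
    counts = df["fusion_type"].value_counts().to_dict()
    extent = {
        "umap1": (df["umap1"].min(), df["umap1"].max()),
        "umap2": (df["umap2"].min(), df["umap2"].max()),
    }
    return counts, extent

# In practice, read one of the source-data files, e.g.:
# df = pd.read_csv("results/final/umap_plots/fuson_plm/best/tf_and_kinase/"
#                  "umap_tf_and_kinase_fusions_source_data.csv")
df = pd.DataFrame({
    "umap1": [1.0, 2.0, 3.0],
    "umap2": [4.0, 5.0, 6.0],
    "fusion_type": ["TF::TF", "TF::TF", "Kinase::Kinase"],
})
counts, extent = summarize_umap_source_data(df)
print(counts)  # → {'TF::TF': 2, 'Kinase::Kinase': 1}
```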
fuson_plm/benchmarking/embedding_exploration/__init__.py ADDED
File without changes
fuson_plm/benchmarking/embedding_exploration/config.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dictionary: key = run name, values = epochs (use this option if you've trained your own model)
+ # Or, the string "FusOn-pLM" (use this option for the "best" checkpoint from the FusOn-pLM paper)
+ FUSON_PLM_CKPT = "FusOn-pLM"
4
+
5
+ # Type of dim reduction
6
+ PLOT_UMAP = True
7
+ PLOT_TSNE = False
8
+
9
+ # Overwriting configs
10
+ PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
fuson_plm/benchmarking/embedding_exploration/data/salokas_2020_tableS3.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
3
+ size 19626
fuson_plm/benchmarking/embedding_exploration/data/tf_and_kinase_fusions.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:372321137ed12b2f8aa7c4891dafd0e88d64d5c5d0ea9c6f3a0aa9d897e8ead6
3
+ size 557262
fuson_plm/benchmarking/embedding_exploration/data/top_genes.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33d568fe413107318caebd5ee260ee66fe8571461ed8f8d1b47888441f7b5034
3
+ size 16695
fuson_plm/benchmarking/embedding_exploration/plot.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import pickle
4
+ from sklearn.manifold import TSNE
5
+ import matplotlib.font_manager as fm
6
+ from matplotlib.font_manager import FontProperties
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.gridspec as gridspec
9
+ import matplotlib.patches as patches
10
+ import seaborn as sns
11
+ import umap
12
+ import os
13
+
14
+ from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
15
+ import fuson_plm.benchmarking.embedding_exploration.config as config
16
+ from fuson_plm.utils.visualizing import set_font
17
+ from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
18
+ from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy
19
+
20
+
21
+ def get_dimred_embeddings(embeddings, dimred_type="umap"):
22
+ if dimred_type=="umap":
23
+ dimred_embeddings = get_umap_embeddings(embeddings)
24
+ return dimred_embeddings
25
+ if dimred_type=="tsne":
26
+ dimred_embeddings = get_tsne_embeddings(embeddings)
27
+ return dimred_embeddings
+ raise ValueError(f"Unknown dimred_type: {dimred_type}")
28
+
29
+ def get_tsne_embeddings(embeddings):
30
+ embeddings = np.array(embeddings)
31
+ tsne = TSNE(n_components=2, random_state=42,perplexity=5)
32
+ tsne_embeddings = tsne.fit_transform(embeddings)
33
+ return tsne_embeddings
34
+
35
+ def get_umap_embeddings(embeddings):
36
+ embeddings = np.array(embeddings)
37
+ umap_model = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, metric='euclidean') # default parameters for UMAP
38
+ umap_embeddings = umap_model.fit_transform(embeddings)
39
+ return umap_embeddings
40
+
41
+ def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
42
+ """
43
+ Plots a circle filled in halves with specified colors.
44
+
45
+ Parameters:
46
+ - ax: Matplotlib axis to draw on.
47
+ - x, y: Coordinates of the marker.
48
+ - left_color: Color of the left half.
49
+ - right_color: Color of the right half.
50
+ - size: Size of the marker.
51
+ """
52
+ radius = (size ** 0.5) / 100 # Scale the radius
53
+ # Create left half-circle (90° to 270°)
54
+ left_half = patches.Wedge((x, y), radius, 90, 270, color=left_color, ec="black")
55
+ # Create right half-circle (270° to 90°)
56
+ right_half = patches.Wedge((x, y), radius, 270, 90, color=right_color, ec="black")
57
+
58
+ # Add both halves to the plot
59
+ ax.add_patch(left_half)
60
+ ax.add_patch(right_half)
61
+
62
+ def plot_umap_scatter_tftf_kk(df, filename="umap.png"):
63
+ """
64
+ Plots a 2D scatterplot of UMAP coordinates with different markers and colors based on 'type'.
65
+ Only for TF::TF and Kinase::Kinase fusions
66
+
67
+ Parameters:
68
+ - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'type' columns.
69
+ """
70
+ set_font()
71
+
72
+ # Define colors for each type
73
+ colors = {
74
+ "TF": "pink",
75
+ "Kinase": "orange"
76
+ }
77
+
78
+ # Define marker types and colors for each combination
79
+ marker_colors = {
80
+ "TF::TF": colors["TF"],
81
+ "Kinase::Kinase": colors["Kinase"],
82
+ }
83
+
84
+ # Create the plot
85
+ fig, ax = plt.subplots(figsize=(10, 8))
86
+ x_min, x_max = df["umap1"].min() - 1, df["umap1"].max() + 1
87
+ y_min, y_max = df["umap2"].min() - 1, df["umap2"].max() + 1
88
+ ax.set_xlim(x_min, x_max)
89
+ ax.set_ylim(y_min, y_max)
90
+
91
+ # Plot each point with the specified half-filled marker
92
+ for i in range(len(df)):
93
+ row = df.iloc[i]
94
+ marker_type = row["fusion_type"]
95
+ x, y = row["umap1"], row["umap2"]
96
+ color = marker_colors[marker_type]
97
+
98
+ ax.scatter(x, y, color=color, s=15, edgecolors="black", linewidth=0.5)
99
+
100
+ # Add custom legend
101
+ legend_elements = [
102
+ patches.Patch(facecolor="pink", edgecolor="black", label="TF::TF"),
103
+ patches.Patch(facecolor="orange", edgecolor="black", label="Kinase::Kinase")
104
+ ]
105
+ ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)
106
+
107
+ # Add labels and title
108
+ plt.xlabel("UMAP 1", fontsize=20)
109
+ plt.ylabel("UMAP 2", fontsize=20)
110
+ plt.title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)
111
+ plt.tight_layout()
112
+
113
+ # Save and show the plot
114
+ plt.savefig(filename, dpi=300)
115
+ plt.show()
116
+
117
+ def plot_umap_scatter_half_filled(df, filename="umap.png"):
118
+ """
119
+ Plots a 2D scatterplot of UMAP coordinates with different markers and colors based on 'type'.
120
+
121
+ Parameters:
122
+ - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'type' columns.
123
+ """
124
+ # Define colors for each type
125
+ colors = {
126
+ "TF": "pink",
127
+ "Kinase": "orange",
128
+ "Other": "grey"
129
+ }
130
+
131
+ # Define marker types and colors for each combination
132
+ marker_colors = {
133
+ "TF::TF": {"left": colors["TF"], "right": colors["TF"]},
134
+ "TF::Other": {"left": colors["TF"], "right": colors["Other"]},
135
+ "Other::TF": {"left": colors["Other"], "right": colors["TF"]},
136
+ "Kinase::Kinase": {"left": colors["Kinase"], "right": colors["Kinase"]},
137
+ "Kinase::Other": {"left": colors["Kinase"], "right": colors["Other"]},
138
+ "Other::Kinase": {"left": colors["Other"], "right": colors["Kinase"]},
139
+ "Kinase::TF": {"left": colors["Kinase"], "right": colors["TF"]},
140
+ "TF::Kinase": {"left": colors["TF"], "right": colors["Kinase"]},
141
+ "Other::Other": {"left": colors["Other"], "right": colors["Other"]}
142
+ }
143
+
144
+ # Create the plot
145
+ fig, ax = plt.subplots(figsize=(10, 8))
146
+ x_min, x_max = df["umap1"].min() - 1, df["umap1"].max() + 1
147
+ y_min, y_max = df["umap2"].min() - 1, df["umap2"].max() + 1
148
+ ax.set_xlim(x_min, x_max)
149
+ ax.set_ylim(y_min, y_max)
150
+
151
+ # Plot each point with the specified half-filled marker
152
+ for i in range(len(df)):
153
+ row = df.iloc[i]
154
+ marker_type = row["fusion_type"]
155
+ x, y = row["umap1"], row["umap2"]
156
+ left_color = marker_colors[marker_type]["left"]
157
+ right_color = marker_colors[marker_type]["right"]
158
+ plot_half_filled_circle(ax, x, y, left_color, right_color, size=100)
159
+
160
+ # Add custom legend
161
+ legend_elements = [
162
+ patches.Patch(facecolor="pink", edgecolor="black", label="TF"),
163
+ patches.Patch(facecolor="orange", edgecolor="black", label="Kinase"),
164
+ patches.Patch(facecolor="grey", edgecolor="black", label="Other")
165
+ ]
166
+ ax.legend(handles=legend_elements, title="Type")
167
+
168
+ # Add labels and title
169
+ plt.xlabel("UMAP 1")
170
+ plt.ylabel("UMAP 2")
171
+ plt.title("UMAP Scatter Plot")
172
+ plt.tight_layout()
173
+
174
+ # Save and show the plot
175
+ plt.savefig(filename, dpi=300)
176
+ plt.show()
177
+
178
+ def get_gene_type(gene, d):
179
+ if gene in d:
180
+ if d[gene] == 'kinase':
181
+ return 'Kinase'
182
+ if d[gene] == 'tf':
183
+ return 'TF'
184
+ # fall through: gene absent from dict, or annotated as neither kinase nor TF
+ return 'Other'
186
+
187
+ def get_tf_and_kinase_fusions_dataset():
188
+ # Load TF and Kinase Fusions
189
+ tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
190
+ print(tf_kinase_parts)
191
+ ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'],tf_kinase_parts['Kinase or TF']))
192
+
193
+ # This one has each row with one fusiongene name
194
+ fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
195
+ fuson_ht_db[['hg','tg']] = fuson_ht_db['fusiongenes'].str.split("::",expand=True)
196
+
197
+ fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
198
+ fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
199
+ fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type']+'::'+fuson_ht_db['tg_type']
200
+ fuson_ht_db['type']=['fusion']*len(fuson_ht_db)
201
+ # Keep 100 things in each category
202
+ # categories = fuson_ht_db['fusion_type'].value_counts().index.tolist() # all categories
+ categories = ["TF::TF","Kinase::Kinase"] # manually set some easier categories
204
+ print(categories)
205
+ plot_df = None
206
+
207
+ for i, cat in enumerate(categories):
208
+ random_sample = fuson_ht_db.loc[fuson_ht_db['fusion_type']==cat].reset_index(drop=True)
209
+ #random_sample = random_sample.sample(n=100, random_state=1).reset_index(drop=True)
210
+ if i==0:
211
+ plot_df = random_sample
212
+ else:
213
+ plot_df = pd.concat([plot_df,random_sample],axis=0).reset_index(drop=True)
214
+
215
+ print(plot_df['fusion_type'].value_counts())
216
+
217
+ # Now, need to add in the embeddings
218
+ plot_df = plot_df[['aa_seq','fusiongenes','fusion_type','type']].rename(
219
+ columns={'aa_seq':'sequence','fusiongenes':'ID'}
220
+ )
221
+
222
+ return plot_df
223
+
224
+ def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = '', dimred_type='umap'):
225
+ fuson_db = pd.read_csv("../../data/fuson_db.csv")
226
+ seq_id_dict = dict(zip(fuson_db['aa_seq'],fuson_db['seq_id']))
227
+
228
+ # add sequences so we can save results/sequence
229
+ data = seqs_with_embeddings[[f'{dimred_type}1',f'{dimred_type}2','sequence','fusion_type','ID']].copy() # copy to avoid SettingWithCopyWarning
230
+ data['seq_id'] = data['sequence'].map(seq_id_dict)
231
+
232
+ tfkinase_save_dir = f"{savedir}"
233
+ os.makedirs(tfkinase_save_dir,exist_ok=True)
234
+ data.to_csv(f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_source_data.csv",index=False)
235
+ plot_umap_scatter_tftf_kk(data,filename=f"{tfkinase_save_dir}/{dimred_type}_tf_and_kinase_fusions_visualization.png")
236
+
237
+ def tf_and_kinase_fusions_plot(dimred_types, output_dir):
238
+ """
239
+ Makes the embeddings, THEN calls the plot. only on the four favorites
240
+ """
241
+ plot_df = get_tf_and_kinase_fusions_dataset()
242
+ plot_df.to_csv("data/tf_and_kinase_fusions.csv",index=False)
243
+
244
+ # path to the pkl file with FOdb embeddings
245
+ input_fname='tf_and_kinase'
246
+ all_embedding_paths = embed_dataset_for_benchmark(
247
+ fuson_ckpts=config.FUSON_PLM_CKPT,
248
+ input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
249
+ average=True, seq_col='sequence',
250
+ benchmark_fusonplm=True,
251
+ benchmark_esm=False,
252
+ benchmark_fo_puncta_ml=False,
253
+ overwrite=config.PERMISSION_TO_OVERWRITE)
254
+
255
+ # For each of the models we are benchmarking, load embeddings and make plots
256
+ log_update("\nEmbedding sequences")
257
+ # loop through the embedding paths and train each one
258
+ for embedding_path, details in all_embedding_paths.items():
259
+ log_update(f"\tBenchmarking embeddings at: {embedding_path}")
260
+ try:
261
+ with open(embedding_path, "rb") as f:
262
+ embeddings = pickle.load(f)
263
+ except Exception as e:
+ raise Exception(f"Cannot read embeddings from {embedding_path}") from e
265
+
266
+ # combine the embeddings and splits into one dataframe
267
+ seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
268
+ seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'}) # the column that was called FusOn-pLM is now called embedding
269
+ seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')
270
+ # get UMAP transform of the embeddings
271
+ for dimred_type in dimred_types:
272
+ dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(),dimred_type=dimred_type)
273
+
274
+ # turn the result into a dataframe, and add it to seqs_with_embeddings
275
+ data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
276
+ # save the umap data!
277
+ model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])
278
+
279
+ seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data
280
+
281
+ # make subdirectory
282
+ intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
283
+ cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
284
+
285
+ os.makedirs(cur_output_dir,exist_ok=True)
286
+ make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir = cur_output_dir, dimred_type=dimred_type)
287
+
288
+ def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir = None, dimred_type='umap'):
289
+ """
290
+ Make plots showing that PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK are embedded distinctly from their heads and tails
291
+ """
292
+ set_font()
293
+
294
+ # Load one sequence each for four proteins in the test set: PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK
295
+ data = pd.read_csv("data/top_genes.csv")
296
+ seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")
297
+ seqs_with_embeddings["Type"] = [""]*len(seqs_with_embeddings)
298
+ seqs_with_embeddings.loc[
299
+ seqs_with_embeddings["gene"].str.contains("::"),"Type"
300
+ ] = "fusion_embeddings"
301
+ heads = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::",expand=True)[0].tolist()
302
+ tails = seqs_with_embeddings.loc[seqs_with_embeddings["gene"].str.contains("::")]["gene"].str.split("::",expand=True)[1].tolist()
303
+ seqs_with_embeddings.loc[
304
+ seqs_with_embeddings["gene"].isin(heads),"Type"
305
+ ] = "h_embeddings"
306
+ seqs_with_embeddings.loc[
307
+ seqs_with_embeddings["gene"].isin(tails),"Type"
308
+ ] = "t_embeddings"
309
+
310
+ # make merge
311
+ merge = seqs_with_embeddings.loc[seqs_with_embeddings['gene'].str.contains('::')].reset_index(drop=True)[['gene','sequence']]
312
+ merge["head"] = merge["gene"].str.split("::",expand=True)[0]
313
+ merge["tail"] = merge["gene"].str.split("::",expand=True)[1]
314
+ merge = pd.merge(merge, seqs_with_embeddings[['gene','sequence']].rename(
315
+ columns={'gene': 'head', 'sequence': 'h_sequence'}),
316
+ on='head',how='left'
317
+ )
318
+ merge = pd.merge(merge, seqs_with_embeddings[['gene','sequence']].rename(
319
+ columns={'gene': 'tail', 'sequence': 't_sequence'}),
320
+ on='tail',how='left'
321
+ )
322
+
323
+ plt.figure()
324
+
325
+ # Define colors and markers
326
+ colors = {
327
+ 'fusion_embeddings': '#cf9dfa', # old color #0C4A4D
328
+ 'h_embeddings': '#eb8888', # Updated to original names; old color #619283
329
+ 't_embeddings': '#5fa3e3', # Updated to original names; old color #619283
330
+ }
331
+ markers = {
332
+ 'fusion_embeddings': 'o',
333
+ 'h_embeddings': '^', # Updated to original names
334
+ 't_embeddings': 'v' # Updated to original names
335
+ }
336
+ label_map = {
337
+ 'fusion_embeddings': 'Fusion',
338
+ 'h_embeddings': 'Head', # Updated label
339
+ 't_embeddings': 'Tail', # Updated label
340
+ }
341
+
342
+ # Create a 2x3 grid of plots
343
+ fig, axes = plt.subplots(2, 3, figsize=(18, 12))
344
+ #fig, axes = plt.subplots(1, 4, figsize= (18, 7))
345
+
346
+ # Get the global min and max for the x and y axis ranges
347
+ all_tsne1 = seqs_with_embeddings[f'{dimred_type}1']
348
+ all_tsne2 = seqs_with_embeddings[f'{dimred_type}2']
349
+ x_min, x_max = all_tsne1.min(), all_tsne1.max()
350
+ y_min, y_max = all_tsne2.min(), all_tsne2.max()
351
+ x_min, x_max = [11, 16] # manually set range for cleaner plotting
352
+ y_min, y_max = [10, 22]
353
+
354
+ # Determine tick positions
355
+ x_ticks = np.arange(x_min, x_max + 1, 1)
356
+ y_ticks = np.arange(y_min, y_max + 1, 1)
357
+
358
+ # Flatten the axes array for easier iteration
359
+ axes = axes.flatten()
360
+
361
+ for i, ax in enumerate(axes):
362
+ # Extract the gene names from the current row
363
+ fgene_name = merge.loc[i, 'gene']
364
+ hgene = merge.loc[i, 'head']
365
+ tgene = merge.loc[i, 'tail']
366
+
367
+ # Filter tsne_embeddings for the relevant entries
368
+ tsne_data = seqs_with_embeddings[seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])]
369
+
370
+ # Plot each type
371
+ for emb_type in tsne_data['Type'].unique():
372
+ subset = tsne_data[tsne_data['Type'] == emb_type]
373
+ ax.scatter(subset[f'{dimred_type}1'], subset[f'{dimred_type}2'], label=label_map[emb_type], color=colors[emb_type], marker=markers[emb_type], s=120, zorder=3)
374
+
375
+ ax.set_title(f'{fgene_name}',fontsize=44)
376
+ label_transform = {
377
+ 'tsne': 't-SNE',
378
+ 'umap': 'UMAP'
379
+ }
380
+ ax.set_xlabel(f'{label_transform[dimred_type]} 1',fontsize=44)
381
+ ax.set_ylabel(f'{label_transform[dimred_type]} 2',fontsize=44)
382
+ ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)
383
+
384
+ # Set the same limits and ticks for all axes
385
+ ax.set_xlim(x_min, x_max)
386
+ ax.set_ylim(y_min, y_max)
387
+ ax.set_xticks(x_ticks)
+ ax.set_yticks(y_ticks)
389
+
390
+ # Rotate x-axis labels
391
+ ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')
392
+
393
+ ax.tick_params(axis='x', labelsize=16)
394
+ ax.tick_params(axis='y', labelsize=16)
395
+
396
+ for label in ax.get_xticklabels():
397
+ label.set_fontsize(24)
398
+ for label in ax.get_yticklabels():
399
+ label.set_fontsize(24)
400
+
401
+ # Set font size for the legend if needed
402
+ if i == 0:
403
+ legend = ax.legend(fontsize=20, markerscale=2, loc='best')
404
+ for text in legend.get_texts():
405
+ text.set_fontsize(24)
406
+
407
+ # Adjust layout to prevent overlap
408
+ plt.tight_layout()
409
+
410
+ # Save the figure before showing it (show() can clear the canvas on some backends)
+ plt.savefig(f'{savedir}/{dimred_type}_favorites_visualization.png', dpi=300)
+ plt.show()
415
+
416
+ # Save the data
417
+ seq_to_id_dict = pd.read_csv("../../data/fuson_db.csv")
418
+ seq_to_id_dict = dict(zip(seq_to_id_dict['aa_seq'],seq_to_id_dict['seq_id']))
419
+ seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
420
+ seqs_with_embeddings[[f'{dimred_type}1',f'{dimred_type}2','sequence','Type','gene','id','seq_id']].to_csv(f"{savedir}/{dimred_type}_favorites_source_data.csv",index=False)
421
+
422
+ def fusion_v_parts_favorites(dimred_types, output_dir):
423
+ """
424
+ Makes the embeddings, THEN calls the plot. only on the four favorites
425
+ """
426
+
427
+ # path to the pkl file with FOdb embeddings
428
+ input_fname='favorites'
429
+ all_embedding_paths = embed_dataset_for_benchmark(
430
+ fuson_ckpts=config.FUSON_PLM_CKPT,
431
+ input_data_path='data/top_genes.csv', input_fname=input_fname,
432
+ average=True, seq_col='sequence',
433
+ benchmark_fusonplm=True,
434
+ benchmark_esm=False,
435
+ benchmark_fo_puncta_ml=False,
436
+ overwrite=config.PERMISSION_TO_OVERWRITE)
437
+
438
+ # For each of the models we are benchmarking, load embeddings and make plots
439
+ log_update("\nEmbedding sequences")
440
+ # loop through the embedding paths and train each one
441
+ for embedding_path, details in all_embedding_paths.items():
442
+ log_update(f"\tBenchmarking embeddings at: {embedding_path}")
443
+ try:
444
+ with open(embedding_path, "rb") as f:
445
+ embeddings = pickle.load(f)
446
+ except Exception as e:
+ raise Exception(f"Cannot read embeddings from {embedding_path}") from e
448
+
449
+ # combine the embeddings and splits into one dataframe
450
+ seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
451
+ seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'}) # the column that was called FusOn-pLM is now called embedding
452
+
453
+ # get UMAP transform of the embeddings
454
+ for dimred_type in dimred_types:
455
+ dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(),dimred_type=dimred_type)
456
+
457
+ # turn the result into a dataframe, and add it to seqs_with_embeddings
458
+ data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
459
+ # save the umap data!
460
+ model_name = "_".join(embedding_path.split('embeddings/')[1].split('/')[1:-1])
461
+
462
+ seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data
463
+
464
+ # make subdirectory
465
+ intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
466
+ cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
467
+
468
+ os.makedirs(cur_output_dir,exist_ok=True)
469
+ make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir = cur_output_dir, dimred_type=dimred_type)
470
+
471
+ def main():
472
+ # make directory to save results
473
+ os.makedirs('results',exist_ok=True)
474
+ output_dir = f'results/{get_local_time()}'
475
+ os.makedirs(output_dir,exist_ok=True)
476
+
477
+ dimred_types = []
478
+ if config.PLOT_UMAP:
479
+ dimred_types.append("umap")
480
+ #os.makedirs(f"{output_dir}/umap_data",exist_ok=True)
481
+ os.makedirs(f"{output_dir}/umap_plots",exist_ok=True)
482
+ if config.PLOT_TSNE:
483
+ dimred_types.append("tsne")
484
+ #os.makedirs(f"{output_dir}/tsne_data",exist_ok=True)
485
+ os.makedirs(f"{output_dir}/tsne_plots",exist_ok=True)
486
+
487
+ with open_logfile(f'{output_dir}/embedding_exploration_log.txt'):
488
+ print_configpy(config)
489
+ # make the disinct embeddings plot
490
+ fusion_v_parts_favorites(dimred_types, output_dir)
491
+
492
+ tf_and_kinase_fusions_plot(dimred_types, output_dir)
493
+
494
+
495
+ if __name__ == "__main__":
496
+ main()
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28c0b51f513da01df3dee3c4e71aa0c583bd57d9878137bdac9e7ebc704694e4
3
+ size 17383
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_visualization.png ADDED
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b26a5a6c2f8f54225fd46f01dab52813532438732624561af8e2e4ad005e5dc7
3
+ size 570073
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png ADDED
fuson_plm/benchmarking/mutation_prediction/README.md CHANGED
@@ -81,7 +81,7 @@ To run, use:
81
  ```
82
  nohup python discover.py > discover.out 2> discover.err &
83
  ```
84
- - All **results** are stored in `idr_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
85
 
86
  Below are the FusOn-pLM paper results in `results/final`:
87
 
 
81
  ```
82
  nohup python discover.py > discover.out 2> discover.err &
83
  ```
84
+ - All **results** are stored in `mutation_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
85
 
86
  Below are the FusOn-pLM paper results in `results/final`:
87
 
fuson_plm/benchmarking/puncta/train.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  import pickle
6
  import os
7
 
8
- from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor, train_predictor_xval
9
  from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
10
  import fuson_plm.benchmarking.puncta.config as config
11
  from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
 
5
  import pickle
6
  import os
7
 
8
+ from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
9
  from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
10
  import fuson_plm.benchmarking.puncta.config as config
11
  from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
fuson_plm/benchmarking/xgboost_predictor.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, roc_auc_score, average_precision_score
+ import xgboost as xgb
+ import numpy as np
+ import pandas as pd
8
+
9
+ def train_final_predictor(X_train, y_train, n_estimators=50,tree_method="hist"):
10
+ clf = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
11
+ clf.fit(X_train, y_train)
12
+ return clf
13
+
14
+ def evaluate_predictor(clf,X_test,y_test,class1_thresh=None):
15
+ # Predicting the labels on test set
16
+ y_pred_test = clf.predict(X_test) # labels with automatic thresholds
17
+ y_pred_prob_test = clf.predict_proba(X_test)[:, 1]
18
+ if class1_thresh is not None: y_pred_customthresh_test = np.where(np.array(y_pred_prob_test) >= class1_thresh, 1, 0)
19
+
20
+ # Calculating metrics - automatic
21
+ accuracy = accuracy_score(y_test, y_pred_test)
22
+ precision = precision_score(y_test, y_pred_test)
23
+ recall = recall_score(y_test, y_pred_test)
24
+ f1 = f1_score(y_test, y_pred_test)
25
+ auroc_prob = roc_auc_score(y_test, y_pred_prob_test)
26
+ auprc_prob = average_precision_score(y_test, y_pred_prob_test)
27
+ auroc_label = roc_auc_score(y_test, y_pred_test)
28
+ auprc_label = average_precision_score(y_test, y_pred_test)
29
+
30
+ automatic_stats_df = pd.DataFrame(data={
31
+ 'Accuracy': [accuracy],
32
+ 'Precision': [precision],
33
+ 'Recall': [recall],
34
+ 'F1 Score': [f1],
35
+ 'AUROC': [auroc_prob],
36
+ 'AUROC Label': [auroc_label],
37
+ 'AUPRC': [auprc_prob],
38
+ 'AUPRC Label': [auprc_label]
39
+ })
40
+
41
+ # Calculating metrics - custom threshold (note that probability ones won't change)
42
+ if class1_thresh is not None:
43
+ accuracy_custom = accuracy_score(y_test, y_pred_customthresh_test)
44
+ precision_custom = precision_score(y_test, y_pred_customthresh_test)
45
+ recall_custom = recall_score(y_test, y_pred_customthresh_test)
46
+ f1_custom = f1_score(y_test, y_pred_customthresh_test)
47
+ auroc_prob_custom = roc_auc_score(y_test, y_pred_prob_test)
48
+ auprc_prob_custom = average_precision_score(y_test, y_pred_prob_test)
49
+ auroc_label_custom = roc_auc_score(y_test, y_pred_customthresh_test)
50
+ auprc_label_custom = average_precision_score(y_test, y_pred_customthresh_test)
51
+
52
+ custom_stats_df = pd.DataFrame(data={
53
+ 'Accuracy': [accuracy_custom],
54
+ 'Precision': [precision_custom],
55
+ 'Recall': [recall_custom],
56
+ 'F1 Score': [f1_custom],
57
+ 'AUROC': [auroc_prob_custom],
58
+ 'AUROC Label': [auroc_label_custom],
59
+ 'AUPRC': [auprc_prob_custom],
60
+ 'AUPRC Label': [auprc_label_custom]
61
+ })
62
+ else:
63
+ custom_stats_df = None
64
+
65
+ return automatic_stats_df, custom_stats_df