Upload MOJO
mojo.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 import math
-from dataclasses import dataclass
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
@@ -510,32 +510,40 @@ class LMHead(nn.Module):
         return out
 
 
-@dataclass
 class MOJOConfig(PretrainedConfig):  # noqa: N801
     model_type = "MOJO"
-    # [old lines 516-538 removed; their contents are not rendered in this diff view]
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.alphabet_size = kwargs.get(
+            "alphabet_size", {"rnaseq": 66, "methylation": 66}
+        )
+        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
+        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
+        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
+        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
+        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
+        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
+        self.num_downsamples = kwargs.get("num_downsamples", 8)
+        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
+        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
+        self.embed_dim = kwargs.get("embed_dim", 512)
+        self.filter_list = kwargs.get("filter_list", [])
+        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
+        self.key_size = kwargs.get("key_size", None)
+        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
+        self.num_layers = kwargs.get("num_layers", 8)
+        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)
+
+        # return
+        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
+            "embeddings_layers_to_save", ()
+        )
+        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
+            "attention_maps_to_save", []
+        )
+
+        self.__post_init__()
 
     def __post_init__(self):
         # Validate attention key size
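For reference, a minimal usage sketch of the rewritten config: with the dataclass fields replaced by a kwargs-driven __init__, every field falls back to its kwargs.get(...) default and any keyword can override it. This is an illustration, not part of the upload; it assumes mojo.py is importable as a module named `mojo`, that `PretrainedConfig` comes from Hugging Face transformers, and that the default values pass the __post_init__ checks (which are outside this hunk).

# Usage sketch (assumptions: mojo.py importable as `mojo`; defaults satisfy
# the __post_init__ validation that follows this hunk).
from mojo import MOJOConfig

config = MOJOConfig()          # every field falls back to its kwargs.get(...) default
print(config.embed_dim)        # 512
print(config.alphabet_size)    # {'rnaseq': 66, 'methylation': 66}

# Fields can be overridden by keyword; unrecognized kwargs are forwarded to
# PretrainedConfig.__init__ and handled there as usual.
small = MOJOConfig(embed_dim=256, num_layers=4, num_attention_heads=8, key_size=32)
small.save_pretrained("mojo-small")  # standard PretrainedConfig serialization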