Upload BulkRNABert
Browse files- bulkrnabert.py +24 -17
bulkrnabert.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
from
|
| 3 |
-
from typing import Optional
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
@@ -198,23 +197,31 @@ class SelfAttentionBlock(nn.Module):
|
|
| 198 |
return output
|
| 199 |
|
| 200 |
|
| 201 |
-
@dataclass
|
| 202 |
class BulkRNABertConfig(PretrainedConfig):
|
| 203 |
model_type = "BulkRNABert"
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
def __post_init__(self):
|
| 220 |
# Validate attention key size
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import Any, Optional
|
|
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
|
|
| 197 |
return output
|
| 198 |
|
| 199 |
|
|
|
|
| 200 |
class BulkRNABertConfig(PretrainedConfig):
|
| 201 |
model_type = "BulkRNABert"
|
| 202 |
+
|
| 203 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 204 |
+
super().__init__(**kwargs)
|
| 205 |
+
self.n_genes = kwargs.get("n_genes", 19_062)
|
| 206 |
+
self.n_expressions_bins = kwargs.get("n_expressions_bins", 64)
|
| 207 |
+
self.embed_dim = kwargs.get("embed_dim", 256)
|
| 208 |
+
self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
|
| 209 |
+
self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
|
| 210 |
+
self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
|
| 211 |
+
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
| 212 |
+
self.key_size = kwargs.get("key_size", None)
|
| 213 |
+
self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 512)
|
| 214 |
+
self.num_layers = kwargs.get("num_layers", 4)
|
| 215 |
+
|
| 216 |
+
# return
|
| 217 |
+
self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
|
| 218 |
+
"embeddings_layers_to_save", ()
|
| 219 |
+
)
|
| 220 |
+
self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
|
| 221 |
+
"attention_maps_to_save", []
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.__post_init__()
|
| 225 |
|
| 226 |
def __post_init__(self):
|
| 227 |
# Validate attention key size
|