Yanisadel committed on
Commit
5d7e1fc
·
verified ·
1 Parent(s): c83f0a9

Upload SegmentEnformer

Browse files
Files changed (2) hide show
  1. config.json +4 -0
  2. segment_enformer.py +75 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "SegmentEnformer"
4
  ],
 
 
 
 
5
  "dim_divisible_by": 128,
6
  "embed_dim": 1536,
7
  "features": [
 
2
  "architectures": [
3
  "SegmentEnformer"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "segment_enformer.SegmentEnformerConfig",
7
+ "AutoModel": "segment_enformer.SegmentEnformer"
8
+ },
9
  "dim_divisible_by": 128,
10
  "embed_dim": 1536,
11
  "features": [
segment_enformer.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from einops import rearrange
4
+ from enformer_pytorch import Enformer
5
+ from transformers import PretrainedConfig, PreTrainedModel
6
+
7
+ from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead
8
+
9
# Default feature names used by SegmentEnformerConfig. The list is handed to
# the segmentation head unchanged, so entry order is presumably significant
# (one output track per name) -- confirm before reordering.
FEATURES = [
    "protein_coding_gene", "lncRNA",
    "exon", "intron",
    "splice_donor", "splice_acceptor",
    "5UTR", "3UTR",
    "CTCF-bound", "polyA_signal",
    "enhancer_Tissue_specific", "enhancer_Tissue_invariant",
    "promoter_Tissue_specific", "promoter_Tissue_invariant",
]
25
+
26
+
27
class SegmentEnformerConfig(PretrainedConfig):
    """Configuration object for ``SegmentEnformer``.

    Args:
        features: Names of the features passed to the segmentation head.
            ``None`` (or omitting the argument) selects the module-level
            ``FEATURES`` defaults.
        embed_dim: Passed to the segmentation head as ``embed_dimension``.
        dim_divisible_by: Passed to the segmentation head as
            ``nucl_per_token``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs
    ):
        # Store a private copy: keeping a reference to the caller's list (or
        # to the shared module-level default) would let a later in-place edit
        # of one config's ``features`` leak into every other config.
        self.features = list(FEATURES if features is None else features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by

        super().__init__(**kwargs)
42
+
43
+
44
class SegmentEnformer(PreTrainedModel):
    """Enformer trunk (stem, conv tower, transformer) followed by a U-Net head.

    The trunk sub-modules are taken from a pretrained ``Enformer`` instance;
    the head is a ``TorchUNetHead`` built from the values in ``config``.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig):
        """Build the trunk and the segmentation head.

        Args:
            config: Supplies ``features``, ``embed_dim`` and
                ``dim_divisible_by`` for the head.
        """
        super().__init__(config=config)

        # NOTE(review): this loads the "EleutherAI/enformer-official-rough"
        # checkpoint on every construction, including when this model's own
        # weights are about to be loaded over it. Kept as-is to preserve
        # behaviour; consider building the trunk from hyper-parameters instead.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        # Only these three sub-modules of the Enformer are retained.
        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x):
        """Run the trunk and the segmentation head on ``x``.

        Implemented as ``forward`` (not ``__call__``) so that
        ``torch.nn.Module.__call__`` dispatches here and registered hooks
        run; ``model(x)`` keeps working exactly as before.

        Args:
            x: 3-D tensor laid out as ``(b, n, d)`` -- presumably
                (batch, sequence length, channels); TODO confirm.

        Returns:
            Whatever ``self.unet_head`` returns for the transformer output.
        """
        # The stem and conv tower consume the tensor with the last two axes
        # swapped relative to the input layout.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)

        # The transformer consumes the original (b, n, d) axis order.
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        # The head again receives the tensor with the last two axes swapped.
        x = rearrange(x, "b n d -> b d n")
        x = self.unet_head(x)

        return x