Yanisadel commited on
Commit
e5a0f63
·
verified ·
1 Parent(s): 855127e

Upload SegmentEnformer

Browse files
Files changed (1) hide show
  1. segment_enformer.py +9 -6
segment_enformer.py CHANGED
@@ -1,10 +1,13 @@
1
- from typing import List
2
 
 
3
  from einops import rearrange
4
  from enformer_pytorch import Enformer
5
  from transformers import PretrainedConfig, PreTrainedModel
6
 
7
- from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead
 
 
8
 
9
  FEATURES = [
10
  "protein_coding_gene",
@@ -32,8 +35,8 @@ class SegmentEnformerConfig(PretrainedConfig):
32
  features: List[str] = FEATURES,
33
  embed_dim: int = 1536,
34
  dim_divisible_by: int = 128,
35
- **kwargs
36
- ):
37
  self.features = features
38
  self.embed_dim = embed_dim
39
  self.dim_divisible_by = dim_divisible_by
@@ -44,7 +47,7 @@ class SegmentEnformerConfig(PretrainedConfig):
44
  class SegmentEnformer(PreTrainedModel):
45
  config_class = SegmentEnformerConfig
46
 
47
- def __init__(self, config: SegmentEnformerConfig):
48
  super().__init__(config=config)
49
 
50
  enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
@@ -60,7 +63,7 @@ class SegmentEnformer(PreTrainedModel):
60
  remove_cls_token=False,
61
  )
62
 
63
- def __call__(self, x):
64
  x = rearrange(x, "b n d -> b d n")
65
  x = self.stem(x)
66
 
 
1
+ from typing import Any, Dict, List
2
 
3
+ import torch
4
  from einops import rearrange
5
  from enformer_pytorch import Enformer
6
  from transformers import PretrainedConfig, PreTrainedModel
7
 
8
+ from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
9
+ TorchUNetHead,
10
+ )
11
 
12
  FEATURES = [
13
  "protein_coding_gene",
 
35
  features: List[str] = FEATURES,
36
  embed_dim: int = 1536,
37
  dim_divisible_by: int = 128,
38
+ **kwargs: Dict[str, Any]
39
+ ) -> None:
40
  self.features = features
41
  self.embed_dim = embed_dim
42
  self.dim_divisible_by = dim_divisible_by
 
47
  class SegmentEnformer(PreTrainedModel):
48
  config_class = SegmentEnformerConfig
49
 
50
+ def __init__(self, config: SegmentEnformerConfig) -> None:
51
  super().__init__(config=config)
52
 
53
  enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
 
63
  remove_cls_token=False,
64
  )
65
 
66
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
67
  x = rearrange(x, "b n d -> b d n")
68
  x = self.stem(x)
69