|
from typing import Any, List

import torch
from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel

from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
    TorchUNetHead,
)
|
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
|
|
class SegmentEnformerConfig(PretrainedConfig):
    """Configuration for `SegmentEnformer`.

    Args:
        features: Genomic element types predicted by the segmentation head.
        embed_dim: Embedding dimension of the Enformer trunk.
        dim_divisible_by: Number of nucleotides covered by each trunk output token.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any
    ) -> None:
        self.features = features
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by

        super().__init__(**kwargs)
|
|
class SegmentEnformer(PreTrainedModel):
    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        super().__init__(config=config)

        # Reuse the pretrained Enformer trunk (stem, convolutional tower and
        # transformer) and attach the segmentation head on top of it.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One-hot encoded DNA (batch, length, 4) -> channels-first for the conv stem.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)

        x = self.conv_tower(x)

        # Channels-last for the transformer.
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        # Channels-first again for the U-Net segmentation head.
        x = rearrange(x, "b n d -> b d n")
        x = self.unet_head(x)

        return x
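

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only): builds the model and runs a forward pass on
# random one-hot DNA. The (batch, length, 4) input layout and the 196,608 bp
# sequence length (Enformer's standard input size, a multiple of
# `dim_divisible_by`) are assumptions made for this example, not requirements
# enforced by this module.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = SegmentEnformerConfig()
    model = SegmentEnformer(config).eval()

    # Random one-hot encoded DNA: (batch=1, length=196_608, channels=4).
    tokens = torch.randint(0, 4, (1, 196_608))
    one_hot = torch.nn.functional.one_hot(tokens, num_classes=4).float()

    with torch.no_grad():
        logits = model(one_hot)

    print(logits.shape)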
|
|