# segment_enformer/segment_enformer.py
# Source: "Upload SegmentEnformer" by Yanisadel (commit e5a0f63, verified), 1.98 kB.
from typing import Any, Dict, List
import torch
from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel
from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
TorchUNetHead,
)
# Names of the annotation tracks handled by the segmentation head. This list
# is the default value of ``SegmentEnformerConfig.features``.
# NOTE(review): the order presumably determines the order of the head's output
# channels -- confirm against TorchUNetHead before reordering or inserting.
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
class SegmentEnformerConfig(PretrainedConfig):
    """Configuration object for the ``segment_enformer`` model type.

    Holds the three values that parameterise the segmentation head; any
    other keyword argument is handed to ``PretrainedConfig``.

    Args:
        features: Names of the features to segment. Defaults to the
            module-level ``FEATURES`` list. The config stores its own copy.
        embed_dim: Embedding dimension given to the segmentation head.
        dim_divisible_by: Value given to the segmentation head as its
            ``nucl_per_token`` argument.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any,
    ) -> None:
        # Store a copy: the default is a shared module-level list, and keeping
        # a reference would let a later in-place edit of ``config.features``
        # leak into FEATURES and into every other default-constructed config.
        self.features = list(features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        super().__init__(**kwargs)
class SegmentEnformer(PreTrainedModel):
    """Pretrained Enformer layers followed by a U-Net segmentation head.

    The model reuses the ``stem``, ``conv_tower`` and ``transformer``
    sub-modules of a pretrained ``Enformer`` and feeds their output to a
    ``TorchUNetHead`` built from the config.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        """Build the model from ``config``.

        Args:
            config: Supplies ``features``, ``embed_dim`` and
                ``dim_divisible_by`` for the segmentation head.
        """
        super().__init__(config=config)

        # NOTE(review): the pretrained Enformer checkpoint is loaded on every
        # construction, including when this model is itself restored from a
        # checkpoint -- presumably intentional to obtain the layer
        # definitions; confirm.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        # Only these three sub-modules are kept; the rest of the loaded
        # Enformer model is not referenced by this class.
        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Enformer layers and the segmentation head on ``x``.

        Defined as ``forward`` (not ``__call__``) so that ``model(x)`` goes
        through ``torch.nn.Module.__call__`` and registered hooks are run.

        Args:
            x: Rank-3 tensor whose axes the rearrange patterns below label
                ``(b, n, d)`` -- presumably batch, sequence position and
                channel; confirm against the caller.

        Returns:
            The output of the segmentation head.
        """
        # The stem and conv tower are fed the "b d n" layout.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)

        # The transformer is fed the "b n d" layout.
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        # The head is fed the "b d n" layout again.
        x = rearrange(x, "b n d -> b d n")
        x = self.unet_head(x)
        return x