from typing import List

from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel

from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead

FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]


class SegmentEnformerConfig(PretrainedConfig):
    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs,
    ):
        self.features = features
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        super().__init__(**kwargs)


class SegmentEnformer(PreTrainedModel):
    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig):
        super().__init__(config=config)

        # Reuse the pretrained Enformer trunk (stem + conv tower + transformer)
        # and replace its output heads with a segmentation (U-Net) head.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x):
        # (batch, length, 4) one-hot -> channels-first for the convolutional stem
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)
        # channels-last for the transformer trunk
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)
        # channels-first again for the segmentation head
        x = rearrange(x, "b n d -> b d n")
        x = self.unet_head(x)
        return x
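
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption, not part of the original module): the model
# is assumed to take one-hot encoded DNA of shape (batch, length, 4), with
# `length` a multiple of `dim_divisible_by` (128). Running this downloads the
# pretrained Enformer weights; the output structure is determined by
# TorchUNetHead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    config = SegmentEnformerConfig()
    model = SegmentEnformer(config)
    model.eval()

    # One random sequence of 196,608 bp (Enformer's usual receptive field);
    # 196,608 / 128 = 1,536 tokens after the conv tower.
    seq_len = 196_608
    tokens = torch.randint(0, 4, (1, seq_len))
    x = torch.nn.functional.one_hot(tokens, num_classes=4).float()

    with torch.no_grad():
        out = model(x)  # shape/format depends on TorchUNetHead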