# segment_enformer/segment_enformer.py
#
# Hugging Face Hub file-view residue, kept for provenance:
#   "Yanisadel's picture" / "Upload SegmentEnformer" / "f7aa1ae verified" /
#   "raw" / "history blame" / "12.9 kB"
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel
def get_activation_fn(activation_name: str) -> Callable:
    """
    Returns torch activation function

    Args:
        activation_name (str): Name of the activation function. Possible values are
            'swish', 'silu', 'relu', 'gelu', 'gelu-no-approx', 'sin'.
            'silu' is an alias of 'swish'. Both 'gelu' and 'gelu-no-approx' map to
            ``torch.nn.functional.gelu`` with its default arguments, i.e. the exact
            (non tanh-approximated) GELU.

    Raises:
        ValueError: If activation_name is not supported

    Returns:
        Callable: Activation function
    """
    # The module docstrings in this file advertise "silu" and "gelu-no-approx"
    # as valid names, so they are accepted here as aliases.
    if activation_name in ("swish", "silu"):
        return nn.functional.silu  # type: ignore
    if activation_name == "relu":
        return nn.functional.relu  # type: ignore
    if activation_name in ("gelu", "gelu-no-approx"):
        return nn.functional.gelu  # type: ignore
    if activation_name == "sin":
        return torch.sin  # type: ignore
    raise ValueError(f"Unsupported activation function: {activation_name}")
class TorchDownSample1D(nn.Module):
    """
    Encoder stage of the 1D U-Net: a stack of convolutions followed by a 2x
    average pooling (torch port of DownSample1D from
    trix.layers.heads.unet_segmentation_head.py).
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: channel count of the incoming tensor.
            output_channels: channel count produced by every convolution.
            activation_fn: activation name resolved through ``get_activation_fn``
                (e.g. "swish", "relu", "gelu", "sin").
            num_layers: how many convolutions to stack.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`; the following ones
        # map `output_channels` -> `output_channels`.
        stacked_convs = []
        current_channels = input_channels
        for _ in range(num_layers):
            stacked_convs.append(
                nn.Conv1d(
                    current_channels,
                    output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            current_channels = output_channels
        self.conv_layers = nn.ModuleList(stacked_convs)

        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the convolution stack, then pool.

        Returns:
            A pair ``(pooled, hidden)``: ``hidden`` is the activation before
            pooling (the caller keeps it as a skip connection) and ``pooled``
            is its average-pooled version.
        """
        features = x
        for layer in self.conv_layers:
            features = self.activation_fn(layer(features))
        return self.avg_pool(features), features
class TorchUpSample1D(nn.Module):
    """
    Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py

    Applies a stack of transposed convolutions, then upsamples the last axis by
    a factor of 2.
    """

    # Modes for which torch.nn.functional.interpolate accepts `align_corners`;
    # passing it with any other mode raises a ValueError.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
        interpolation_method: str = "nearest",
    ):
        """
        Args:
            input_channels: number of input channels.
            output_channels: number of output channels.
            activation_fn: name of the activation function to use, resolved
                through ``get_activation_fn`` (e.g. "swish", "relu", "gelu",
                "sin").
            num_layers: number of convolution layers.
            interpolation_method: ``mode`` forwarded to
                ``torch.nn.functional.interpolate``. For the 3-D
                (batch, channels, length) tensors handled here torch supports
                "nearest", "linear", "area" and "nearest-exact".
        """
        super().__init__()

        self.conv_transpose_layers = nn.ModuleList(
            [
                nn.ConvTranspose1d(
                    in_channels=input_channels if i == 0 else output_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
                for i in range(num_layers)
            ]
        )

        self.interpolation_mode = interpolation_method
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor whose last axis is the one being upsampled.

        Returns:
            The convolved tensor with its last axis doubled in length.
        """
        for conv_layer in self.conv_transpose_layers:
            x = self.activation_fn(conv_layer(x))

        # `align_corners` must stay None for non-interpolating modes such as
        # "nearest", "area" and "nearest-exact".
        align_corners = (
            False if self.interpolation_mode in self._ALIGN_CORNERS_MODES else None
        )
        x = nn.functional.interpolate(
            x,
            scale_factor=2,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )
        return x
class TorchFinalConv1D(nn.Module):
    """
    Output stage of the 1D U-Net: a stack of convolutions in which every layer
    except the last is followed by the activation (torch port of FinalConv1D
    from trix.layers.heads.unet_segmentation_head.py).
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: channel count of the incoming tensor.
            output_channels: channel count produced by every convolution.
            activation_fn: activation name resolved through ``get_activation_fn``
                (e.g. "swish", "relu", "gelu", "sin").
            num_layers: how many convolutions to stack.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`.
        stacked_convs = []
        current_channels = input_channels
        for _ in range(num_layers):
            stacked_convs.append(
                nn.Conv1d(
                    current_channels,
                    output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            current_channels = output_channels
        self.conv_layers = nn.ModuleList(stacked_convs)

        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolutions; the output of the last one is left raw."""
        last_index = len(self.conv_layers) - 1
        for index, layer in enumerate(self.conv_layers):
            x = layer(x)
            if index != last_index:
                x = self.activation_fn(x)
        return x
class TorchUNET1DSegmentationHead(nn.Module):
    """
    Torch adaptation of UNET1DSegmentationHead in
    trix.layers.heads.unet_segmentation_head.py

    A 1D U-Net: a chain of downsampling blocks, a mirrored chain of upsampling
    blocks with additive skip connections, and a final convolution block.
    """

    def __init__(
        self,
        num_classes: int,
        input_embed_dim: int,
        output_channels_list: Tuple[int, ...] = (64, 128, 256),
        activation_fn: str = "swish",
        num_conv_layers_per_block: int = 2,
        upsampling_interpolation_method: str = "nearest",
    ):
        """
        Args:
            num_classes: number of classes to segment. The final block emits
                ``num_classes * 2`` channels.
            input_embed_dim: number of channels of the input tensor.
            output_channels_list: number of output channels at each level of
                the UNET. Any sequence of ints is accepted (e.g. a list coming
                from a JSON config); it is converted to a tuple internally.
            activation_fn: name of the activation function to use, resolved
                through ``get_activation_fn`` (e.g. "swish", "relu", "gelu",
                "sin").
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: interpolation mode forwarded to the
                upsampling blocks.
        """
        super().__init__()

        # The tuple concatenations below would raise TypeError on a list.
        output_channels_list = tuple(output_channels_list)

        input_channels_list = (input_embed_dim,) + output_channels_list[:-1]
        self.num_pooling_layers = len(output_channels_list)

        self.downsample_blocks = nn.ModuleList(
            [
                TorchDownSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                )
                for input_channels, output_channels in zip(
                    input_channels_list, output_channels_list
                )
            ]
        )

        # Decoder mirrors the encoder: the first upsampling block keeps the
        # deepest channel count, each following one steps down a level.
        reversed_channels = output_channels_list[::-1]
        input_channels_list = (output_channels_list[-1],) + reversed_channels[:-1]

        self.upsample_blocks = nn.ModuleList(
            [
                TorchUpSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                    interpolation_method=upsampling_interpolation_method,
                )
                for input_channels, output_channels in zip(
                    input_channels_list, reversed_channels
                )
            ]
        )

        self.final_block = TorchFinalConv1D(
            activation_fn=activation_fn,
            input_channels=output_channels_list[0],
            output_channels=num_classes * 2,
            num_layers=num_conv_layers_per_block,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: channels-first tensor (it is fed directly to Conv1d layers);
                its last axis must be divisible by
                ``2 ** num_pooling_layers``.

        Raises:
            ValueError: if the last axis is not divisible by
                ``2 ** num_pooling_layers``.

        Returns:
            Output of the final convolution block.
        """
        if x.shape[-1] % 2**self.num_pooling_layers:
            raise ValueError(
                "Input length must be divisible by 2 to the power of "
                "the number of pooling layers."
            )

        # Keep the pre-pooling activation of every level for the skip
        # connections.
        hiddens = []
        for downsample_block in self.downsample_blocks:
            x, hidden = downsample_block(x)
            hiddens.append(hidden)

        # Deepest level first: each upsampled map is summed with the matching
        # encoder activation.
        for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
            x = upsample_block(x) + hidden

        x = self.final_block(x)
        return x
class TorchUNetHead(nn.Module):
    """
    Torch adaptation of UNetHead in
    genomics_research/segmentnt/layers/segmentation_head.py

    Runs a 1D U-Net over channels-first embeddings and projects every position
    to ``nucl_per_token * num_classes * len(features)`` logits.
    """

    def __init__(
        self,
        features: List[str],
        num_classes: int = 2,
        embed_dimension: int = 1024,
        nucl_per_token: int = 6,
        num_layers: int = 2,
        remove_cls_token: bool = True,
    ):
        """
        Args:
            features (List[str]): List of features names.
            num_classes (int): Number of classes.
            embed_dimension (int): Embedding dimension. The U-Net emits
                ``2 * (embed_dimension // 2)`` channels while the linear layer
                expects ``embed_dimension``, so this should be even.
            nucl_per_token (int): Number of nucleotides per token.
            num_layers (int): Number of levels of the U-Net.
            remove_cls_token (bool): Whether to remove the CLS token (the first
                position of the sequence axis) before the U-Net.
        """
        super().__init__()
        self._num_features = len(features)
        self._num_classes = num_classes
        self.nucl_per_token = nucl_per_token
        self.remove_cls_token = remove_cls_token

        self.unet = TorchUNET1DSegmentationHead(
            num_classes=embed_dimension // 2,
            output_channels_list=tuple(
                embed_dimension * (2**i) for i in range(num_layers)
            ),
            input_embed_dim=embed_dimension,
        )
        self.fc = nn.Linear(
            embed_dimension,
            self.nucl_per_token * self._num_classes * self._num_features,
        )

    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: channels-first embeddings, ``(batch, embed_dimension, seq_len)``
                (plus one leading position on the last axis when
                ``remove_cls_token`` is set).
            sequence_mask: accepted for interface compatibility; not used.

        Returns:
            ``{"logits": tensor}`` where the tensor has shape
            ``(batch, seq_len * nucl_per_token, num_features, num_classes)``.
        """
        if self.remove_cls_token:
            # x goes straight into Conv1d layers, so axis 1 is the channel axis
            # and the tokens live on the last axis: drop the CLS position
            # there. Slicing axis 1 would remove a channel and break the first
            # convolution.
            x = x[:, :, 1:]
        x = self.unet(x)
        x = nn.functional.silu(x)

        # Back to (batch, seq_len, channels) for the per-position projection.
        x = x.transpose(2, 1)
        logits = self.fc(x)

        batch_size, seq_len, _ = x.shape
        # reshape behaves like view whenever a view is possible and also copes
        # with a non-contiguous projection output.
        logits = logits.reshape(
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )
        return {"logits": logits}
# Names of the annotation tracks handled by the segmentation head. This list is
# the default `features` of SegmentEnformerConfig, and the head sizes its output
# projection from the number of entries.
# NOTE(review): the order presumably fixes the index of each feature along the
# feature axis of the logits — confirm before reordering or inserting entries.
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
class SegmentEnformerConfig(PretrainedConfig):
    """
    Configuration of the SegmentEnformer model.

    Attributes are set before ``PretrainedConfig.__init__`` is called with the
    remaining keyword arguments.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: Optional[List[str]] = None,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            features: names of the features to predict. Defaults to a copy of
                the module-level ``FEATURES`` list.
            embed_dim: embedding dimension handed to the segmentation head.
            dim_divisible_by: handed to the segmentation head as its
                ``nucl_per_token``.
            **kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        # Always store a private copy: keeping a reference to the shared
        # module-level default (or to the caller's list) would let a mutation
        # of one config leak into every other config.
        self.features = list(FEATURES if features is None else features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by

        super().__init__(**kwargs)
class SegmentEnformer(PreTrainedModel):
    """
    Enformer trunk (stem, convolution tower, transformer) followed by a U-Net
    segmentation head.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        """
        Args:
            config: model configuration; ``features``, ``embed_dim`` and
                ``dim_divisible_by`` parameterize the segmentation head.
        """
        super().__init__(config=config)

        # NOTE(review): this loads the pretrained Enformer checkpoint every
        # time the model is constructed; only its stem, conv tower and
        # transformer are kept.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        keeps running registered hooks; ``model(x)`` works exactly as before.

        Args:
            x: 3-D tensor laid out as (batch, length, channels).
                NOTE(review): presumably a one-hot encoded DNA sequence —
                confirm against the Enformer stem.

        Returns:
            The dictionary produced by the segmentation head (key "logits").
        """
        # Convolutions expect channels-first input.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)

        # The transformer works on (batch, length, channels).
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        # Back to channels-first for the convolutional head.
        x = rearrange(x, "b n d -> b d n")
        x = self.unet_head(x)

        return x