# segment_borzoi / segment_borzoi.py
# Commit e0c22b2 by Yanisadel: "Upload SegmentBorzoi" (verified).
from typing import Any, Callable, Dict, List, Optional, Tuple
import borzoi_pytorch
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum
from transformers import PretrainedConfig, PreTrainedModel
def get_activation_fn(activation_name: str) -> Callable:
    """
    Returns torch activation function

    Args:
        activation_name (str): Name of the activation function. Supported values
            are 'swish' (alias 'silu'), 'relu', 'gelu' (alias 'gelu-no-approx')
            and 'sin'. The aliases are the names listed in the layer docstrings
            of this module.

    Raises:
        ValueError: If activation_name is not supported

    Returns:
        Callable: Activation function
    """
    # NOTE(review): torch's `gelu` is the exact (erf-based) form by default, so
    # 'gelu' and 'gelu-no-approx' resolve to the same function here. If the JAX
    # original used the tanh approximation for 'gelu', results may differ
    # slightly — confirm before relying on 'gelu' with ported weights.
    activations: Dict[str, Callable] = {
        "swish": nn.functional.silu,
        "silu": nn.functional.silu,
        "relu": nn.functional.relu,
        "gelu": nn.functional.gelu,
        "gelu-no-approx": nn.functional.gelu,
        "sin": torch.sin,
    }
    try:
        return activations[activation_name]
    except KeyError:
        raise ValueError(
            f"Unsupported activation function: {activation_name}"
        ) from None
class TorchDownSample1D(nn.Module):
    """
    Torch port of DownSample1D from trix.layers.heads.unet_segmentation_head.py.

    A stack of length-preserving 1D convolutions, each followed by the
    activation, and then an average pooling that halves the sequence length.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels of the block input.
            output_channels: number of channels produced by every convolution.
            activation_fn: activation name, resolved with `get_activation_fn`
                (which raises ValueError for unsupported names).
            num_layers: number of convolution layers in the stack.
        """
        super().__init__()
        convs = []
        in_width = input_channels
        for _ in range(num_layers):
            # kernel 3 / stride 1 / padding 1 keeps the sequence length unchanged.
            convs.append(
                nn.Conv1d(
                    in_channels=in_width,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            in_width = output_channels
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv_layers = nn.ModuleList(convs)
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (pooled, skip): `skip` is the activated output of the last
            convolution (before pooling); `pooled` is `skip` average-pooled by 2.
        """
        skip = x
        for conv in self.conv_layers:
            skip = self.activation_fn(conv(skip))
        return self.avg_pool(skip), skip
class TorchUpSample1D(nn.Module):
    """
    Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py

    A stack of length-preserving transposed 1D convolutions (each followed by
    the activation), then interpolation that doubles the sequence length.
    """

    # Interpolation modes for which `torch.nn.functional.interpolate` accepts an
    # `align_corners` argument. For any other mode (e.g. "nearest", "area",
    # "nearest-exact") torch raises ValueError if `align_corners` is set at all.
    _ALIGN_CORNERS_MODES: Tuple[str, ...] = (
        "linear",
        "bilinear",
        "bicubic",
        "trilinear",
    )

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
        interpolation_method: str = "nearest",
    ):
        """
        Args:
            input_channels: number of input channels.
            output_channels: number of output channels.
            activation_fn: name of the activation function, resolved with
                `get_activation_fn`.
            num_layers: number of transposed convolution layers.
            interpolation_method: passed as `mode` to
                `torch.nn.functional.interpolate`. For 3D (batch, channels,
                length) inputs torch supports "nearest", "nearest-exact",
                "linear" and "area"; the JAX names "cubic", "lanczos3" and
                "lanczos5" are not available in torch.
        """
        super().__init__()
        self.conv_transpose_layers = nn.ModuleList(
            [
                nn.ConvTranspose1d(
                    in_channels=input_channels if i == 0 else output_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
                for i in range(num_layers)
            ]
        )
        self.interpolation_mode = interpolation_method
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv_layer in self.conv_transpose_layers:
            x = self.activation_fn(conv_layer(x))
        # Only pass align_corners where torch allows it; previously it was set
        # to False for every non-"nearest" mode, which crashed for "area" and
        # "nearest-exact".
        align_corners = (
            False if self.interpolation_mode in self._ALIGN_CORNERS_MODES else None
        )
        return nn.functional.interpolate(
            x,
            scale_factor=2,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )
class TorchFinalConv1D(nn.Module):
    """
    Torch port of FinalConv1D from trix.layers.heads.unet_segmentation_head.py.

    A stack of length-preserving 1D convolutions. The activation is applied
    between consecutive convolutions but not after the last one, so the block's
    output is left un-activated.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels of the block input.
            output_channels: number of channels produced by every convolution.
            activation_fn: activation name, resolved with `get_activation_fn`.
            num_layers: number of convolution layers.
        """
        super().__init__()
        channel_pairs = [
            (input_channels if idx == 0 else output_channels, output_channels)
            for idx in range(num_layers)
        ]
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            for c_in, c_out in channel_pairs
        )
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        layers = list(self.conv_layers)
        if not layers:
            return x
        # Every convolution but the last is followed by the activation.
        for layer in layers[:-1]:
            x = self.activation_fn(layer(x))
        return layers[-1](x)
class TorchUNET1DSegmentationHead(nn.Module):
    """
    Torch adaptation of UNET1DSegmentationHead in
    trix.layers.heads.unet_segmentation_head.py

    Encoder: one TorchDownSample1D per level (each halves the length).
    Decoder: one TorchUpSample1D per level (each doubles the length), with the
    matching encoder skip output added back. A TorchFinalConv1D then maps to
    `num_classes * 2` channels.
    """

    def __init__(
        self,
        num_classes: int,
        input_embed_dim: int,
        output_channels_list: Tuple[int, ...] = (64, 128, 256),
        activation_fn: str = "swish",
        num_conv_layers_per_block: int = 2,
        upsampling_interpolation_method: str = "nearest",
    ):
        """
        Args:
            num_classes: number of classes to segment; the final block outputs
                `num_classes * 2` channels.
            input_embed_dim: number of channels of the input tensor.
            output_channels_list: number of output channels at each level of the
                UNET. Any non-empty sequence is accepted (it is converted to a
                tuple).
            activation_fn: name of the activation function, resolved with
                `get_activation_fn`.
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: `mode` passed to
                `torch.nn.functional.interpolate` in the upsampling blocks.

        Raises:
            ValueError: if `output_channels_list` is empty.
        """
        super().__init__()
        # Normalise to a tuple so the tuple concatenations below also work when
        # a list is passed.
        channels = tuple(output_channels_list)
        if not channels:
            raise ValueError("output_channels_list must contain at least one level.")
        self.num_pooling_layers = len(channels)

        # Encoder: level i maps channels[i - 1] (or the input) -> channels[i].
        down_inputs = (input_embed_dim,) + channels[:-1]
        self.downsample_blocks = nn.ModuleList(
            [
                TorchDownSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                )
                for input_channels, output_channels in zip(down_inputs, channels)
            ]
        )

        # Decoder mirrors the encoder: the first upsampling block keeps the
        # bottleneck width, later ones step back down through the levels.
        reversed_channels = channels[::-1]
        up_inputs = (reversed_channels[0],) + reversed_channels[:-1]
        self.upsample_blocks = nn.ModuleList(
            [
                TorchUpSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                    interpolation_method=upsampling_interpolation_method,
                )
                for input_channels, output_channels in zip(
                    up_inputs, reversed_channels
                )
            ]
        )
        self.final_block = TorchFinalConv1D(
            activation_fn=activation_fn,
            input_channels=channels[0],
            output_channels=num_classes * 2,
            num_layers=num_conv_layers_per_block,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor whose last axis is the sequence axis; its length must be
                divisible by 2 ** num_pooling_layers.

        Raises:
            ValueError: if the sequence length is not divisible as required.
        """
        if x.shape[-1] % 2**self.num_pooling_layers:
            raise ValueError(
                "Input length must be divisible by 2 to the power of "
                "the number of pooling layers."
            )
        hiddens = []
        for downsample_block in self.downsample_blocks:
            x, hidden = downsample_block(x)
            hiddens.append(hidden)
        # Skip connections are consumed deepest-first.
        for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
            x = upsample_block(x) + hidden
        x = self.final_block(x)
        return x
class TorchUNetHead(nn.Module):
    """
    Torch adaptation of UNetHead in
    genomics_research/segmentnt/layers/segmentation_head.py

    Runs a 1D UNet over channel-first token embeddings, applies SiLU, and
    projects every token to `nucl_per_token` positions with
    `num_features * num_classes` logits each.
    """

    def __init__(
        self,
        features: List[str],
        num_classes: int = 2,
        embed_dimension: int = 1024,
        nucl_per_token: int = 6,
        num_layers: int = 2,
        remove_cls_token: bool = True,
    ):
        """
        Args:
            features (List[str]): List of features names; only its length is used.
            num_classes (int): Number of classes per feature.
            embed_dimension (int): Channel dimension of the input embeddings.
            nucl_per_token (int): Number of nucleotides predicted per token.
            num_layers (int): Number of UNet levels (pooling steps).
            remove_cls_token (bool): Whether to drop the first token of the
                sequence before running the UNet.
        """
        super().__init__()
        self._num_features = len(features)
        self._num_classes = num_classes
        self.nucl_per_token = nucl_per_token
        self.remove_cls_token = remove_cls_token
        self.unet = TorchUNET1DSegmentationHead(
            num_classes=embed_dimension // 2,
            output_channels_list=tuple(
                embed_dimension * (2**i) for i in range(num_layers)
            ),
            input_embed_dim=embed_dimension,
        )
        self.fc = nn.Linear(
            embed_dimension,
            self.nucl_per_token * self._num_classes * self._num_features,
        )

    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: channel-first embeddings of shape (batch, embed_dimension, length).
                Channels are on axis 1 because the UNet is built from Conv1d
                layers. After optional CLS removal, the length must be divisible
                by 2 ** num_layers.
            sequence_mask: unused; kept for interface compatibility.

        Returns:
            {"logits": tensor of shape
             (batch, length * nucl_per_token, num_features, num_classes)},
            where `length` is measured after optional CLS removal.
        """
        if self.remove_cls_token:
            # The sequence axis is the LAST axis (channel-first layout), so the
            # CLS token is x[..., 0]. The previous `x[:, 1:]` dropped the first
            # embedding channel instead, which broke the UNet's channel count.
            x = x[..., 1:]
        x = self.unet(x)
        x = nn.functional.silu(x)
        x = x.transpose(2, 1)  # -> (batch, length, embed_dimension)
        logits = self.fc(x)
        batch_size, seq_len, _ = x.shape
        logits = logits.view(  # noqa
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )
        return {"logits": logits}
# Default genomic annotation labels, used as the default value of
# `SegmentBorzoiConfig.features`. The segmentation head sizes its feature axis
# from len(features); presumably index i of that axis corresponds to FEATURES[i]
# — NOTE(review): confirm the ordering against the training labels.
FEATURES = [
"protein_coding_gene",
"lncRNA",
"exon",
"intron",
"splice_donor",
"splice_acceptor",
"5UTR",
"3UTR",
"CTCF-bound",
"polyA_signal",
"enhancer_Tissue_specific",
"enhancer_Tissue_invariant",
"promoter_Tissue_specific",
"promoter_Tissue_invariant",
]
class SegmentBorzoiConfig(PretrainedConfig):
    """
    Configuration for `SegmentBorzoi`.

    Args:
        features: names of the segmented features; defaults to a copy of the
            module-level FEATURES list.
        embed_dim: embedding width of the segmentation head and of the
            replacement attention layers.
        dim_divisible_by: used by SegmentBorzoi as the head's `nucl_per_token`.
        attention_dim_key: key/query dimension per attention head.
        num_attention_heads: number of attention heads.
        num_rel_pos_features: size of the relative positional embedding.
        **kwargs: forwarded to `PretrainedConfig`.
    """

    model_type = "segment_borzoi"

    def __init__(
        self,
        features: Optional[List[str]] = None,
        embed_dim: int = 1536,
        dim_divisible_by: int = 32,
        attention_dim_key: int = 64,
        num_attention_heads: int = 8,
        num_rel_pos_features: int = 32,
        **kwargs: Any,
    ):
        # Store a copy: previously the module-level FEATURES list itself was the
        # default, so mutating `config.features` silently changed it for every
        # later config.
        self.features = list(FEATURES if features is None else features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        self.attention_dim_key = attention_dim_key
        self.num_attention_heads = num_attention_heads
        self.num_rel_pos_features = num_rel_pos_features
        super().__init__(**kwargs)
class SegmentBorzoi(PreTrainedModel):
    """
    Borzoi trunk topped with a UNet segmentation head.

    The stem, convolution towers, transformer tower, upsampling path, crop and
    final convolutions are taken from the borzoi_pytorch model
    "johahi/borzoi-replicate-0"; a TorchUNetHead produces the logits.

    NOTE(review): `__init__` calls `borzoi_pytorch.Borzoi.from_pretrained`, which
    presumably downloads (or reads from cache) the Borzoi weights every time the
    model is constructed, including when a SegmentBorzoi checkpoint is loaded.
    """

    config_class = SegmentBorzoiConfig

    # Channel count expected by the head's first convolution. It must equal the
    # number of channels produced by Borzoi's `final_joined_convs`, whose output
    # is fed to the head in `forward`.
    _BORZOI_OUTPUT_CHANNELS = 1920

    def __init__(self, config: SegmentBorzoiConfig):
        super().__init__(config=config)
        borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0")
        # Stem
        self.stem = borzoi.conv_dna
        # Conv tower
        self.res_tower = borzoi.res_tower
        self.unet1 = borzoi.unet1
        self._max_pool = borzoi._max_pool
        # Transformer tower
        self.transformer = borzoi.transformer
        # UNet convolution layers
        self.horizontal_conv1 = borzoi.horizontal_conv1
        self.horizontal_conv0 = borzoi.horizontal_conv0
        self.upsampling_unet1 = borzoi.upsampling_unet1
        self.upsampling_unet0 = borzoi.upsampling_unet0
        self.separable1 = borzoi.separable1
        self.separable0 = borzoi.separable0
        # Target length crop
        self.crop = borzoi.crop
        # Final convolution block
        self.final_joined_convs = borzoi.final_joined_convs
        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )
        # Replace the attention module of every transformer block with
        # BorzoiAttentionLayer, whose relative positional embeddings differ
        # from those of the imported implementation.
        for layer in self.transformer:
            layer[0].fn[1] = BorzoiAttentionLayer(  # type: ignore
                config.embed_dim,
                heads=config.num_attention_heads,
                dim_key=config.attention_dim_key,
                dim_value=config.embed_dim // config.num_attention_heads,
                dropout=0.05,
                pos_dropout=0.01,
                num_rel_pos_features=config.num_rel_pos_features,
            )
        # The head's first conv reads the Borzoi output channels. It must
        # produce `config.embed_dim` channels, since that is the width the rest
        # of the first downsample block is built for. The output width was
        # previously hard-coded to 1536, which only matched the default.
        self.unet_head.unet.downsample_blocks[0].conv_layers[0] = nn.Conv1d(
            in_channels=self._BORZOI_OUTPUT_CHANNELS,
            out_channels=config.embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Remove the bias of `conv_layer[1]` in both separable blocks
        # (presumably to match the reference weights — confirm).
        self.separable1.conv_layer[1].bias = None
        self.separable0.conv_layer[1].bias = None

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: input sequence tensor; axes 1 and 2 are swapped before the stem.
                Presumably (batch, length, 4) one-hot DNA — TODO confirm.

        Returns:
            The dict produced by the segmentation head, with key "logits".
        """
        # Stem
        x = x.transpose(1, 2)
        x = self.stem(x)
        # Conv tower
        x_unet0 = self.res_tower(x)
        x_unet1 = self.unet1(x_unet0)
        x = self._max_pool(x_unet1)
        # Transformer tower (operates on (batch, length, channels))
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        x = x.permute(0, 2, 1)
        # UNet conv
        x_unet1 = self.horizontal_conv1(x_unet1)
        x_unet0 = self.horizontal_conv0(x_unet0)
        # UNet upsampling and separable convolutions
        x = self.upsampling_unet1(x)
        x += x_unet1
        x = self.separable1(x)
        x = self.upsampling_unet0(x)
        x += x_unet0
        x = self.separable0(x)
        # Target length crop
        x = self.crop(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        # Final convolution block
        x = self.final_joined_convs(x)
        x = self.unet_head(x)
        return x
# Custom attention layer for the PyTorch model: the attention layer of the imported
# borzoi_pytorch model uses different relative positional embeddings, so it is
# re-implemented below (BorzoiAttentionLayer) together with its positional helpers.
def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor:
    """
    Left-pad `tensor` with size-1 axes until it has `num_dims` dimensions.

    A tensor that already has `num_dims` or more dimensions is returned as-is.
    """
    while tensor.dim() < num_dims:
        tensor = tensor.unsqueeze(0)
    return tensor
def get_positional_features_central_mask_borzoi(
    positions: torch.Tensor, feature_size: int, seq_length: int
) -> torch.Tensor:
    """
    Positional features using a central mask (allow only central features).

    Feature k (k = 1..feature_size) is 1.0 where |position| < r**k - 1 and 0.0
    elsewhere, with r = (seq_length + 1) ** (1 / feature_size). The widest
    window therefore has width seq_length, up to float32 rounding.

    Args:
        positions: integer or float tensor of (relative) positions.
        feature_size: number of mask features.
        seq_length: sequence length used to scale the window widths.

    Returns:
        float32 tensor of shape positions.shape + (feature_size,), on the same
        device as `positions`.
    """
    pow_rate = torch.exp(torch.log(torch.tensor(seq_length + 1.0)) / feature_size)
    center_widths = torch.pow(pow_rate, torch.arange(1, feature_size + 1).float()) - 1
    # The widths are computed on the CPU in float32 exactly as before, which
    # keeps the thresholds bit-identical. Only then are they moved next to
    # `positions`; otherwise CUDA positions were compared against a CPU tensor.
    center_widths = center_widths.to(positions.device)
    center_widths = _prepend_dims(center_widths, positions.ndim)
    outputs = (center_widths > torch.abs(positions).unsqueeze(-1)).float()
    return outputs
def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor:
    """
    Compute positional embedding for Borzoi. Note that it is different than the one
    used in Enformer.

    Returns a float tensor of shape (2 * seq_len - 1, feature_size), with one row
    per relative distance from -(seq_len - 1) to seq_len - 1. The first
    feature_size // 2 columns are central-mask features of the distance; the
    remaining columns are the same features multiplied by sign(distance).

    Raises:
        ValueError: if feature_size is odd.
    """
    num_components = 2
    distances = torch.arange(-seq_len + 1, seq_len)
    if feature_size % num_components:
        raise ValueError(
            f"feature size is not divisible by number of components ({num_components})"
        )
    central = get_positional_features_central_mask_borzoi(
        distances, feature_size // num_components, seq_len
    )
    signed = torch.sign(distances).unsqueeze(-1) * central
    return torch.cat((central, signed), dim=-1)
def relative_shift(x: torch.Tensor) -> torch.Tensor:
    """
    Re-index relative-position logits so the last axis indexes keys.

    For a 4D input of shape (b, h, n, 2n - 1), where column j of the last axis
    holds the logit for relative distance j - (n - 1), the result has shape
    (b, h, n, n) with out[..., i, k] == x[..., i, k - i + n - 1].
    The shift is done with a pad-and-reshape trick instead of gathering.
    """
    zero_column = torch.zeros_like(x[..., :1])
    padded = torch.cat((zero_column, x), dim=-1)
    _, num_heads, rows, cols = padded.shape
    skewed = padded.reshape(-1, num_heads, cols, rows)[:, :, 1:, :]  # noqa: FKA100
    aligned = skewed.reshape(-1, num_heads, rows, cols - 1)  # noqa: FKA100
    keep = (cols + 1) // 2
    return aligned[..., :keep]
class BorzoiAttentionLayer(nn.Module):
    """
    Multi-head self-attention with Borzoi-style relative positional logits.

    Attention logits are the sum of a content term (q + rel_content_bias) . k
    and a relative-position term (q + rel_pos_bias) . rel_k. Here rel_k is a
    linear projection of `get_positional_embed_borzoi`, re-indexed with
    `relative_shift`.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_rel_pos_features: int,
        heads: int = 8,
        dim_key: int = 64,
        dim_value: int = 64,
        dropout: float = 0.0,
        pos_dropout: float = 0.0,
    ) -> None:
        """
        Args:
            dim: input/output feature dimension (last axis of x).
            num_rel_pos_features: size of the relative positional embedding
                (must be even, see get_positional_embed_borzoi).
            heads: number of attention heads.
            dim_key: query/key dimension per head.
            dim_value: value dimension per head.
            dropout: dropout on the attention weights.
            pos_dropout: dropout on the positional embedding.
        """
        super().__init__()
        self.scale = dim_key**-0.5
        self.heads = heads
        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)
        self.to_out = nn.Linear(dim_value * heads, dim)
        # Zero-initialised output projection: a freshly built layer outputs zeros.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)
        self.num_rel_pos_features = num_rel_pos_features
        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
        self.rel_content_bias = nn.Parameter(
            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
        )
        self.rel_pos_bias = nn.Parameter(
            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
        )
        # dropouts
        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (batch, n, dim) (rank fixed by the
                "b n (h d)" rearrange below).

        Returns:
            Tensor of shape (batch, n, dim).
        """
        n, h = x.shape[-2], self.heads
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        q, k, v = map(  # noqa
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h),  # type: ignore
            (q, k, v),
        )
        q = q * self.scale
        content_logits = einsum(
            "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k
        )
        positions = get_positional_embed_borzoi(n, self.num_rel_pos_features)
        # The embedding is built on the CPU in float32. Move it next to the
        # activations; otherwise `to_rel_k` fails for CUDA or half-precision
        # inputs with a device/dtype mismatch.
        positions = positions.to(device=x.device, dtype=x.dtype)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)
        rel_k = rearrange(rel_k, "n (h d) -> h n d", h=h)
        rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)
        logits = content_logits + rel_logits
        attn = logits.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)