|
from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
|
import borzoi_pytorch |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from torch import einsum |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
def get_activation_fn(activation_name: str) -> Callable: |
|
""" |
|
    Returns the torch activation function corresponding to activation_name.
|
|
|
Args: |
|
activation_name (str): Name of the activation function. Possible values are |
|
'swish', 'relu', 'gelu', 'sin' |
|
|
|
Raises: |
|
ValueError: If activation_name is not supported |
|
|
|
Returns: |
|
Callable: Activation function |
|
""" |
|
if activation_name == "swish": |
|
return nn.functional.silu |
|
elif activation_name == "relu": |
|
return nn.functional.relu |
|
elif activation_name == "gelu": |
|
return nn.functional.gelu |
|
elif activation_name == "sin": |
|
return torch.sin |
|
else: |
|
raise ValueError(f"Unsupported activation function: {activation_name}") |
|
|
|
|
|
class TorchDownSample1D(nn.Module): |
|
""" |
|
Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
output_channels: int, |
|
activation_fn: str = "swish", |
|
num_layers: int = 2, |
|
): |
|
""" |
|
Args: |
|
input_channels: number of input channels |
|
output_channels: number of output channels. |
|
            activation_fn: name of the activation function to use.
                Should be one of "swish", "relu", "gelu", "sin"
                (the values supported by get_activation_fn).
|
num_layers: number of convolution layers. |
|
""" |
|
super().__init__() |
|
|
|
self.conv_layers = nn.ModuleList( |
|
[ |
|
nn.Conv1d( |
|
in_channels=input_channels if i == 0 else output_channels, |
|
out_channels=output_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0) |
|
|
|
self.activation_fn: Callable = get_activation_fn(activation_fn) |
|
|
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
for conv_layer in self.conv_layers: |
|
x = self.activation_fn(conv_layer(x)) |
|
hidden = x |
|
x = self.avg_pool(hidden) |
|
return x, hidden |
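

# Shape sketch (illustrative, hypothetical sizes): a downsampling block halves the
# sequence length and also returns the pre-pooling activation that later serves as
# the U-net skip connection.
def _example_downsample_shapes() -> None:
    block = TorchDownSample1D(input_channels=16, output_channels=32)
    x = torch.randn(2, 16, 128)          # (batch, channels, length)
    pooled, hidden = block(x)
    assert hidden.shape == (2, 32, 128)  # skip connection, full resolution
    assert pooled.shape == (2, 32, 64)   # length halved by average pooling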
|
|
|
|
|
class TorchUpSample1D(nn.Module): |
|
""" |
|
Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
output_channels: int, |
|
activation_fn: str = "swish", |
|
num_layers: int = 2, |
|
interpolation_method: str = "nearest", |
|
): |
|
""" |
|
Args: |
|
input_channels: number of input channels. |
|
output_channels: number of output channels. |
|
            activation_fn: name of the activation function to use.
                Should be one of "swish", "relu", "gelu", "sin"
                (the values supported by get_activation_fn).
            interpolation_method: method to be used for upsampling interpolation.
                Should be a 1D mode accepted by torch.nn.functional.interpolate,
                e.g. "nearest" or "linear".
|
num_layers: number of convolution layers. |
|
""" |
|
super().__init__() |
|
|
|
self.conv_transpose_layers = nn.ModuleList( |
|
[ |
|
nn.ConvTranspose1d( |
|
in_channels=input_channels if i == 0 else output_channels, |
|
out_channels=output_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.interpolation_mode = interpolation_method |
|
self.activation_fn: Callable = get_activation_fn(activation_fn) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
for conv_layer in self.conv_transpose_layers: |
|
x = self.activation_fn(conv_layer(x)) |
|
x = nn.functional.interpolate( |
|
x, |
|
scale_factor=2, |
|
mode=self.interpolation_mode, |
|
align_corners=False if self.interpolation_mode != "nearest" else None, |
|
) |
|
return x |
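

# Shape sketch (illustrative, hypothetical sizes): an upsampling block doubles the
# sequence length by interpolation after its transposed convolutions, which keep
# the length unchanged (kernel_size=3, stride=1, padding=1).
def _example_upsample_shapes() -> None:
    block = TorchUpSample1D(input_channels=32, output_channels=16)
    x = torch.randn(2, 32, 64)      # (batch, channels, length)
    y = block(x)
    assert y.shape == (2, 16, 128)  # length doubled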
|
|
|
|
|
class TorchFinalConv1D(nn.Module): |
|
""" |
|
Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
output_channels: int, |
|
activation_fn: str = "swish", |
|
num_layers: int = 2, |
|
): |
|
""" |
|
Args: |
|
input_channels: number of input channels |
|
output_channels: number of output channels. |
|
            activation_fn: name of the activation function to use.
                Should be one of "swish", "relu", "gelu", "sin"
                (the values supported by get_activation_fn).
            num_layers: number of convolution layers.
|
""" |
|
super().__init__() |
|
|
|
self.conv_layers = nn.ModuleList( |
|
[ |
|
nn.Conv1d( |
|
in_channels=input_channels if i == 0 else output_channels, |
|
out_channels=output_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.activation_fn: Callable = get_activation_fn(activation_fn) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
for i, conv_layer in enumerate(self.conv_layers): |
|
x = conv_layer(x) |
|
if i < len(self.conv_layers) - 1: |
|
x = self.activation_fn(x) |
|
return x |
|
|
|
|
|
class TorchUNET1DSegmentationHead(nn.Module): |
|
""" |
|
Torch adaptation of UNET1DSegmentationHead in |
|
trix.layers.heads.unet_segmentation_head.py |
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_classes: int, |
|
input_embed_dim: int, |
|
output_channels_list: Tuple[int, ...] = (64, 128, 256), |
|
activation_fn: str = "swish", |
|
num_conv_layers_per_block: int = 2, |
|
upsampling_interpolation_method: str = "nearest", |
|
): |
|
""" |
|
        Args:
            num_classes: number of classes to segment.
            input_embed_dim: number of channels of the input embeddings.
            output_channels_list: number of output channels at each level of
                the UNET.
            activation_fn: name of the activation function to use.
                Should be one of "swish", "relu", "gelu", "sin"
                (the values supported by get_activation_fn).
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: method to be used for
                interpolation in upsampling blocks. Should be a 1D mode accepted
                by torch.nn.functional.interpolate, e.g. "nearest" or "linear".
|
""" |
|
super().__init__() |
|
|
|
input_channels_list = (input_embed_dim,) + output_channels_list[:-1] |
|
|
|
self.num_pooling_layers = len(output_channels_list) |
|
self.downsample_blocks = nn.ModuleList( |
|
[ |
|
TorchDownSample1D( |
|
input_channels=input_channels, |
|
output_channels=output_channels, |
|
activation_fn=activation_fn, |
|
num_layers=num_conv_layers_per_block, |
|
) |
|
for input_channels, output_channels in zip( |
|
input_channels_list, output_channels_list |
|
) |
|
] |
|
) |
|
|
|
input_channels_list = (output_channels_list[-1],) + tuple( |
|
list(reversed(output_channels_list))[:-1] |
|
) |
|
|
|
self.upsample_blocks = nn.ModuleList( |
|
[ |
|
TorchUpSample1D( |
|
input_channels=input_channels, |
|
output_channels=output_channels, |
|
activation_fn=activation_fn, |
|
num_layers=num_conv_layers_per_block, |
|
interpolation_method=upsampling_interpolation_method, |
|
) |
|
for input_channels, output_channels in zip( |
|
input_channels_list, reversed(output_channels_list) |
|
) |
|
] |
|
) |
|
|
|
self.final_block = TorchFinalConv1D( |
|
activation_fn=activation_fn, |
|
input_channels=output_channels_list[0], |
|
output_channels=num_classes * 2, |
|
num_layers=num_conv_layers_per_block, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if x.shape[-1] % 2**self.num_pooling_layers: |
|
raise ValueError( |
|
"Input length must be divisible by 2 to the power of " |
|
"the number of pooling layers." |
|
) |
|
|
|
hiddens = [] |
|
for downsample_block in self.downsample_blocks: |
|
x, hidden = downsample_block(x) |
|
hiddens.append(hidden) |
|
|
|
for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)): |
|
x = upsample_block(x) + hidden |
|
|
|
x = self.final_block(x) |
|
return x |
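

# Shape sketch (illustrative, hypothetical sizes): the input length must be divisible
# by 2 ** len(output_channels_list); the head preserves the sequence length and
# outputs num_classes * 2 channels.
def _example_unet_segmentation_head_shapes() -> None:
    head = TorchUNET1DSegmentationHead(
        num_classes=3, input_embed_dim=16, output_channels_list=(32, 64)
    )
    x = torch.randn(2, 16, 64)      # length 64 is divisible by 2 ** 2
    out = head(x)
    assert out.shape == (2, 6, 64)  # (batch, num_classes * 2, length)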
|
|
|
|
|
class TorchUNetHead(nn.Module): |
|
""" |
|
Torch adaptation of UNetHead in |
|
genomics_research/segmentnt/layers/segmentation_head.py |
|
""" |
|
|
|
def __init__( |
|
self, |
|
features: List[str], |
|
num_classes: int = 2, |
|
embed_dimension: int = 1024, |
|
nucl_per_token: int = 6, |
|
num_layers: int = 2, |
|
remove_cls_token: bool = True, |
|
): |
|
""" |
|
Args: |
|
            features (List[str]): List of feature names.
            num_classes (int): Number of classes per feature.
            embed_dimension (int): Embedding dimension.
            nucl_per_token (int): Number of nucleotides per token.
            num_layers (int): Number of U-net levels in the segmentation head.
            remove_cls_token (bool): Whether to remove the CLS token.
|
""" |
|
super().__init__() |
|
self._num_features = len(features) |
|
self._num_classes = num_classes |
|
self.nucl_per_token = nucl_per_token |
|
self.remove_cls_token = remove_cls_token |
|
|
|
self.unet = TorchUNET1DSegmentationHead( |
|
num_classes=embed_dimension // 2, |
|
output_channels_list=tuple( |
|
embed_dimension * (2**i) for i in range(num_layers) |
|
), |
|
input_embed_dim=embed_dimension, |
|
) |
|
|
|
self.fc = nn.Linear( |
|
embed_dimension, |
|
self.nucl_per_token * self._num_classes * self._num_features, |
|
) |
|
|
|
    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        if self.remove_cls_token:
            x = x[:, 1:]

        x = self.unet(x)
        x = nn.functional.silu(x)

        # (batch, embed_dimension, num_tokens) -> (batch, num_tokens, embed_dimension)
        x = x.transpose(2, 1)
        logits = self.fc(x)

        # one prediction per nucleotide, per feature and per class
        batch_size, seq_len, _ = x.shape
        logits = logits.view(
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )
|
|
|
return {"logits": logits} |
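

# Shape sketch (illustrative, hypothetical sizes): with remove_cls_token=False the
# head takes channels-first embeddings and returns one logit per nucleotide, per
# feature and per class.
def _example_unet_head_logits() -> None:
    head = TorchUNetHead(
        features=["exon", "intron"],
        num_classes=2,
        embed_dimension=16,
        nucl_per_token=4,
        num_layers=2,
        remove_cls_token=False,
    )
    x = torch.randn(2, 16, 8)  # (batch, embed_dimension, tokens); 8 % 2 ** 2 == 0
    out = head(x)
    assert out["logits"].shape == (2, 8 * 4, 2, 2)  # (batch, nucleotides, features, classes)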
|
|
|
|
|
FEATURES = [ |
|
"protein_coding_gene", |
|
"lncRNA", |
|
"exon", |
|
"intron", |
|
"splice_donor", |
|
"splice_acceptor", |
|
"5UTR", |
|
"3UTR", |
|
"CTCF-bound", |
|
"polyA_signal", |
|
"enhancer_Tissue_specific", |
|
"enhancer_Tissue_invariant", |
|
"promoter_Tissue_specific", |
|
"promoter_Tissue_invariant", |
|
] |
|
|
|
|
|
class SegmentBorzoiConfig(PretrainedConfig): |
|
model_type = "segment_borzoi" |
|
|
|
def __init__( |
|
self, |
|
features: List[str] = FEATURES, |
|
embed_dim: int = 1536, |
|
dim_divisible_by: int = 32, |
|
attention_dim_key: int = 64, |
|
num_attention_heads: int = 8, |
|
num_rel_pos_features: int = 32, |
|
        **kwargs: Any,
|
): |
|
self.features = features |
|
self.embed_dim = embed_dim |
|
self.dim_divisible_by = dim_divisible_by |
|
self.attention_dim_key = attention_dim_key |
|
self.num_attention_heads = num_attention_heads |
|
self.num_rel_pos_features = num_rel_pos_features |
|
|
|
super().__init__(**kwargs) |
|
|
|
|
|
class SegmentBorzoi(PreTrainedModel): |
|
config_class = SegmentBorzoiConfig |
|
|
|
    def __init__(self, config: SegmentBorzoiConfig):
        super().__init__(config=config)
        borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0")

        # Stem
        self.stem = borzoi.conv_dna

        # Convolution tower
        self.res_tower = borzoi.res_tower
        self.unet1 = borzoi.unet1
        self._max_pool = borzoi._max_pool

        # Transformer tower
        self.transformer = borzoi.transformer

        # U-net connections
        self.horizontal_conv1 = borzoi.horizontal_conv1
        self.horizontal_conv0 = borzoi.horizontal_conv0
        self.upsampling_unet1 = borzoi.upsampling_unet1
        self.upsampling_unet0 = borzoi.upsampling_unet0
        self.separable1 = borzoi.separable1
        self.separable0 = borzoi.separable0

        # Cropping layer
        self.crop = borzoi.crop

        # Final convolutions before the segmentation head
        self.final_joined_convs = borzoi.final_joined_convs

        # Segmentation head
        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

        # Replace the backbone's attention layers with the BorzoiAttentionLayer
        # defined below.
        for layer in self.transformer:
            layer[0].fn[1] = BorzoiAttentionLayer(
                config.embed_dim,
                heads=config.num_attention_heads,
                dim_key=config.attention_dim_key,
                dim_value=config.embed_dim // config.num_attention_heads,
                dropout=0.05,
                pos_dropout=0.01,
                num_rel_pos_features=config.num_rel_pos_features,
            )

        # The first convolution of the segmentation head receives the 1920-channel
        # output of final_joined_convs rather than embed_dim channels.
        self.unet_head.unet.downsample_blocks[0].conv_layers[0] = nn.Conv1d(
            in_channels=1920, out_channels=1536, kernel_size=3, stride=1, padding=1
        )

        # These separable convolutions are used without bias terms.
        self.separable1.conv_layer[1].bias = None
        self.separable0.conv_layer[1].bias = None
|
|
|
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Stem (input arrives channels-last; Borzoi's convolutions expect channels first)
        x = x.transpose(1, 2)
        x = self.stem(x)

        # Convolution tower
        x_unet0 = self.res_tower(x)
        x_unet1 = self.unet1(x_unet0)
        x = self._max_pool(x_unet1)

        # Transformer tower (operates channels-last)
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        x = x.permute(0, 2, 1)

        # U-net connections
        x_unet1 = self.horizontal_conv1(x_unet1)
        x_unet0 = self.horizontal_conv0(x_unet0)

        x = self.upsampling_unet1(x)
        x += x_unet1
        x = self.separable1(x)
        x = self.upsampling_unet0(x)
        x += x_unet0
        x = self.separable0(x)

        # Crop to the target window
        x = self.crop(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)

        # Final convolutions
        x = self.final_joined_convs(x)

        # Segmentation head
        x = self.unet_head(x)
|
|
|
return x |
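

# End-to-end usage sketch (illustrative, not an official example). Instantiating
# SegmentBorzoi downloads the pretrained Borzoi backbone
# ("johahi/borzoi-replicate-0"), which is large. The 524,288 bp input length below
# is the one typically used with that backbone; the exact output window depends on
# its cropping layer.
def _example_segment_borzoi() -> None:
    config = SegmentBorzoiConfig()
    model = SegmentBorzoi(config).eval()

    # One-hot encoded DNA, shape (batch, sequence_length, 4)
    tokens = torch.randint(0, 4, (1, 524_288))
    one_hot = nn.functional.one_hot(tokens, num_classes=4).float()

    with torch.no_grad():
        out = model(one_hot)

    # out["logits"]: (batch, n_nucleotides, n_features, 2) - one binary prediction
    # per nucleotide and per genomic feature.
    print(out["logits"].shape)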
|
|
|
|
|
|
|
|
|
def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor: |
|
"""Prepends dimensions to match the required shape.""" |
|
for _ in range(num_dims - tensor.dim()): |
|
tensor = tensor.unsqueeze(0) |
|
return tensor |
|
|
|
|
|
def get_positional_features_central_mask_borzoi( |
|
positions: torch.Tensor, feature_size: int, seq_length: int |
|
) -> torch.Tensor: |
|
"""Positional features using a central mask (allow only central features).""" |
|
pow_rate = torch.exp(torch.log(torch.tensor(seq_length + 1.0)) / feature_size) |
|
center_widths = torch.pow(pow_rate, torch.arange(1, feature_size + 1).float()) - 1 |
|
center_widths = _prepend_dims(center_widths, positions.ndim) |
|
outputs = (center_widths > torch.abs(positions).unsqueeze(-1)).float() |
|
return outputs |
|
|
|
|
|
def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor: |
|
""" |
|
Compute positional embedding for Borzoi. Note that it is different than the one |
|
used in Enformer. |
|
""" |
|
distances = torch.arange(-seq_len + 1, seq_len) |
|
|
|
num_components = 2 |
|
|
|
if (feature_size % num_components) != 0: |
|
raise ValueError( |
|
f"feature size is not divisible by number of components ({num_components})" |
|
) |
|
|
|
num_basis_per_class = feature_size // num_components |
|
|
|
embeddings = [] |
|
|
|
embeddings.append( |
|
get_positional_features_central_mask_borzoi( |
|
distances, num_basis_per_class, seq_len |
|
) |
|
) |
|
|
|
embeddings = torch.cat(embeddings, dim=-1) |
|
embeddings = torch.cat( |
|
(embeddings, torch.sign(distances).unsqueeze(-1) * embeddings), dim=-1 |
|
) |
|
return embeddings |
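

# Shape sketch (illustrative): for a sequence of length n, the embedding covers the
# 2 * n - 1 possible relative distances with feature_size features (half central-mask
# features, half their signed copies).
def _example_positional_embed_shape() -> None:
    emb = get_positional_embed_borzoi(seq_len=10, feature_size=32)
    assert emb.shape == (19, 32)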
|
|
|
|
|
def relative_shift(x: torch.Tensor) -> torch.Tensor:
    """
    Standard relative-attention shift: converts logits over 2 * n - 1 relative
    distances, shape (..., n, 2 * n - 1), into an (..., n, n) attention matrix.
    """
    to_pad = torch.zeros_like(x[..., :1])
|
x = torch.cat((to_pad, x), dim=-1) |
|
_, h, t1, t2 = x.shape |
|
x = x.reshape(-1, h, t2, t1) |
|
x = x[:, :, 1:, :] |
|
x = x.reshape(-1, h, t1, t2 - 1) |
|
return x[..., : ((t2 + 1) // 2)] |
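

# Shape sketch (illustrative): relative_shift maps logits computed against the
# 2 * n - 1 relative positions onto an (n, n) attention matrix.
def _example_relative_shift_shape() -> None:
    b, h, n = 1, 2, 8
    rel_logits = torch.randn(b, h, n, 2 * n - 1)
    out = relative_shift(rel_logits)
    assert out.shape == (b, h, n, n)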
|
|
|
|
|
class BorzoiAttentionLayer(nn.Module):
    """
    Multi-head attention with Borzoi-style relative positional encoding, used to
    replace the attention layers of the pretrained Borzoi backbone.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_rel_pos_features: int,
        heads: int = 8,
        dim_key: int = 64,
        dim_value: int = 64,
        dropout: float = 0.0,
        pos_dropout: float = 0.0,
    ) -> None:
|
super().__init__() |
|
self.scale = dim_key**-0.5 |
|
self.heads = heads |
|
|
|
self.to_q = nn.Linear(dim, dim_key * heads, bias=False) |
|
self.to_k = nn.Linear(dim, dim_key * heads, bias=False) |
|
self.to_v = nn.Linear(dim, dim_value * heads, bias=False) |
|
|
|
self.to_out = nn.Linear(dim_value * heads, dim) |
|
nn.init.zeros_(self.to_out.weight) |
|
nn.init.zeros_(self.to_out.bias) |
|
|
|
self.num_rel_pos_features = num_rel_pos_features |
|
|
|
self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False) |
|
self.rel_content_bias = nn.Parameter( |
|
torch.randn(1, heads, 1, dim_key) |
|
) |
|
self.rel_pos_bias = nn.Parameter( |
|
torch.randn(1, heads, 1, dim_key) |
|
) |
|
|
|
|
|
|
|
self.pos_dropout = nn.Dropout(pos_dropout) |
|
self.attn_dropout = nn.Dropout(dropout) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
n, h = x.shape[-2], self.heads |
|
|
|
q = self.to_q(x) |
|
k = self.to_k(x) |
|
v = self.to_v(x) |
|
|
|
q, k, v = map( |
|
lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), |
|
(q, k, v), |
|
) |
|
|
|
q = q * self.scale |
|
|
|
content_logits = einsum( |
|
"b h i d, b h j d -> b h i j", q + self.rel_content_bias, k |
|
) |
|
|
|
        # Relative positional embeddings are built on CPU; move them to the
        # input's device and dtype before projecting.
        positions = get_positional_embed_borzoi(n, self.num_rel_pos_features)
        positions = positions.to(device=x.device, dtype=x.dtype)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)
|
|
|
rel_k = rearrange(rel_k, "n (h d) -> h n d", h=h) |
|
rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k) |
|
rel_logits = relative_shift(rel_logits) |
|
|
|
logits = content_logits + rel_logits |
|
attn = logits.softmax(dim=-1) |
|
attn = self.attn_dropout(attn) |
|
|
|
out = einsum("b h i j, b h j d -> b h i d", attn, v) |
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
return self.to_out(out) |
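

# Usage sketch (illustrative, hypothetical sizes): the attention layer maps
# (batch, seq_len, dim) to (batch, seq_len, dim); num_rel_pos_features must be even.
def _example_attention_layer() -> None:
    layer = BorzoiAttentionLayer(
        64, num_rel_pos_features=32, heads=4, dim_key=16, dim_value=16
    )
    x = torch.randn(2, 10, 64)
    out = layer(x)
    assert out.shape == (2, 10, 64)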
|
|