from typing import Any, Callable, Dict, List, Optional, Tuple import borzoi_pytorch import torch import torch.nn as nn from einops import rearrange from torch import einsum from transformers import PretrainedConfig, PreTrainedModel def get_activation_fn(activation_name: str) -> Callable: """ Returns torch activation function Args: activation_name (str): Name of the activation function. Possible values are 'swish', 'relu', 'gelu', 'sin' Raises: ValueError: If activation_name is not supported Returns: Callable: Activation function """ if activation_name == "swish": return nn.functional.silu # type: ignore elif activation_name == "relu": return nn.functional.relu # type: ignore elif activation_name == "gelu": return nn.functional.gelu # type: ignore elif activation_name == "sin": return torch.sin # type: ignore else: raise ValueError(f"Unsupported activation function: {activation_name}") class TorchDownSample1D(nn.Module): """ Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py """ def __init__( self, input_channels: int, output_channels: int, activation_fn: str = "swish", num_layers: int = 2, ): """ Args: input_channels: number of input channels output_channels: number of output channels. activation_fn: name of the activation function to use. Should be one of "gelu", "gelu-no-approx", "relu", "swish", "silu", "sin". num_layers: number of convolution layers. """ super().__init__() self.conv_layers = nn.ModuleList( [ nn.Conv1d( in_channels=input_channels if i == 0 else output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, ) for i in range(num_layers) ] ) self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0) self.activation_fn: Callable = get_activation_fn(activation_fn) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: for conv_layer in self.conv_layers: x = self.activation_fn(conv_layer(x)) hidden = x x = self.avg_pool(hidden) return x, hidden class TorchUpSample1D(nn.Module): """ Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py """ def __init__( self, input_channels: int, output_channels: int, activation_fn: str = "swish", num_layers: int = 2, interpolation_method: str = "nearest", ): """ Args: input_channels: number of input channels. output_channels: number of output channels. activation_fn: name of the activation function to use. Should be one of "gelu", "gelu-no-approx", "relu", "swish", "silu", "sin". interpolation_method: Method to be used for upsampling interpolation. Should be one of "nearest", "linear", "cubic", "lanczos3", "lanczos5". num_layers: number of convolution layers. """ super().__init__() self.conv_transpose_layers = nn.ModuleList( [ nn.ConvTranspose1d( in_channels=input_channels if i == 0 else output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, ) for i in range(num_layers) ] ) self.interpolation_mode = interpolation_method self.activation_fn: Callable = get_activation_fn(activation_fn) def forward(self, x: torch.Tensor) -> torch.Tensor: for conv_layer in self.conv_transpose_layers: x = self.activation_fn(conv_layer(x)) x = nn.functional.interpolate( x, scale_factor=2, mode=self.interpolation_mode, align_corners=False if self.interpolation_mode != "nearest" else None, ) return x class TorchFinalConv1D(nn.Module): """ Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py """ def __init__( self, input_channels: int, output_channels: int, activation_fn: str = "swish", num_layers: int = 2, ): """ Args: input_channels: number of input channels output_channels: number of output channels. activation_fn: name of the activation function to use. Should be one of "gelu", "gelu-no-approx", "relu", "swish", "silu", "sin". num_layers: number of convolution layers. name: module name. """ super().__init__() self.conv_layers = nn.ModuleList( [ nn.Conv1d( in_channels=input_channels if i == 0 else output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, ) for i in range(num_layers) ] ) self.activation_fn: Callable = get_activation_fn(activation_fn) def forward(self, x: torch.Tensor) -> torch.Tensor: for i, conv_layer in enumerate(self.conv_layers): x = conv_layer(x) if i < len(self.conv_layers) - 1: x = self.activation_fn(x) return x class TorchUNET1DSegmentationHead(nn.Module): """ Torch adaptation of UNET1DSegmentationHead in trix.layers.heads.unet_segmentation_head.py """ def __init__( self, num_classes: int, input_embed_dim: int, output_channels_list: Tuple[int, ...] = (64, 128, 256), activation_fn: str = "swish", num_conv_layers_per_block: int = 2, upsampling_interpolation_method: str = "nearest", ): """ Args: num_classes: number of classes to segment output_channels_list: list of the number of output channel at each level of the UNET activation_fn: name of the activation function to use. Should be one of "gelu", "gelu-no-approx", "relu", "swish", "silu", "sin". num_conv_layers_per_block: number of convolution layers per block. upsampling_interpolation_method: Method to be used for interpolation in upsampling blocks. Should be one of "nearest", "linear", "cubic", "lanczos3", "lanczos5". """ super().__init__() input_channels_list = (input_embed_dim,) + output_channels_list[:-1] self.num_pooling_layers = len(output_channels_list) self.downsample_blocks = nn.ModuleList( [ TorchDownSample1D( input_channels=input_channels, output_channels=output_channels, activation_fn=activation_fn, num_layers=num_conv_layers_per_block, ) for input_channels, output_channels in zip( input_channels_list, output_channels_list ) ] ) input_channels_list = (output_channels_list[-1],) + tuple( list(reversed(output_channels_list))[:-1] ) self.upsample_blocks = nn.ModuleList( [ TorchUpSample1D( input_channels=input_channels, output_channels=output_channels, activation_fn=activation_fn, num_layers=num_conv_layers_per_block, interpolation_method=upsampling_interpolation_method, ) for input_channels, output_channels in zip( input_channels_list, reversed(output_channels_list) ) ] ) self.final_block = TorchFinalConv1D( activation_fn=activation_fn, input_channels=output_channels_list[0], output_channels=num_classes * 2, num_layers=num_conv_layers_per_block, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if x.shape[-1] % 2**self.num_pooling_layers: raise ValueError( "Input length must be divisible by 2 to the power of " "the number of pooling layers." ) hiddens = [] for downsample_block in self.downsample_blocks: x, hidden = downsample_block(x) hiddens.append(hidden) for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)): x = upsample_block(x) + hidden x = self.final_block(x) return x class TorchUNetHead(nn.Module): """ Torch adaptation of UNetHead in genomics_research/segmentnt/layers/segmentation_head.py """ def __init__( self, features: List[str], num_classes: int = 2, embed_dimension: int = 1024, nucl_per_token: int = 6, num_layers: int = 2, remove_cls_token: bool = True, ): """ Args: features (List[str]): List of features names. num_classes (int): Number of classes. embed_dimension (int): Embedding dimension. nucl_per_token (int): Number of nucleotides per token. num_layers (int): Number of layers. remove_cls_token (bool): Whether to remove the CLS token. name: Name the layer. Defaults to None. """ super().__init__() self._num_features = len(features) self._num_classes = num_classes self.nucl_per_token = nucl_per_token self.remove_cls_token = remove_cls_token self.unet = TorchUNET1DSegmentationHead( num_classes=embed_dimension // 2, output_channels_list=tuple( embed_dimension * (2**i) for i in range(num_layers) ), input_embed_dim=embed_dimension, ) self.fc = nn.Linear( embed_dimension, self.nucl_per_token * self._num_classes * self._num_features, ) def forward( self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None ) -> Dict[str, torch.Tensor]: if self.remove_cls_token: x = x[:, 1:] x = self.unet(x) x = nn.functional.silu(x) x = x.transpose(2, 1) logits = self.fc(x) batch_size, seq_len, _ = x.shape logits = logits.view( # noqa batch_size, seq_len * self.nucl_per_token, self._num_features, self._num_classes, ) return {"logits": logits} FEATURES = [ "protein_coding_gene", "lncRNA", "exon", "intron", "splice_donor", "splice_acceptor", "5UTR", "3UTR", "CTCF-bound", "polyA_signal", "enhancer_Tissue_specific", "enhancer_Tissue_invariant", "promoter_Tissue_specific", "promoter_Tissue_invariant", ] class SegmentBorzoiConfig(PretrainedConfig): model_type = "segment_borzoi" def __init__( self, features: List[str] = FEATURES, embed_dim: int = 1536, dim_divisible_by: int = 32, attention_dim_key: int = 64, num_attention_heads: int = 8, num_rel_pos_features: int = 32, **kwargs: Dict[str, Any], ): self.features = features self.embed_dim = embed_dim self.dim_divisible_by = dim_divisible_by self.attention_dim_key = attention_dim_key self.num_attention_heads = num_attention_heads self.num_rel_pos_features = num_rel_pos_features super().__init__(**kwargs) class SegmentBorzoi(PreTrainedModel): config_class = SegmentBorzoiConfig def __init__(self, config: SegmentBorzoiConfig): super().__init__(config=config) borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0") # Stem self.stem = borzoi.conv_dna # Conv tower self.res_tower = borzoi.res_tower self.unet1 = borzoi.unet1 self._max_pool = borzoi._max_pool # Transformer tower self.transformer = borzoi.transformer # UNet convolution layers self.horizontal_conv1 = borzoi.horizontal_conv1 self.horizontal_conv0 = borzoi.horizontal_conv0 self.upsampling_unet1 = borzoi.upsampling_unet1 self.upsampling_unet0 = borzoi.upsampling_unet0 self.separable1 = borzoi.separable1 self.separable0 = borzoi.separable0 # Target length crop self.crop = borzoi.crop # Final convolution block self.final_joined_convs = borzoi.final_joined_convs self.unet_head = TorchUNetHead( features=config.features, embed_dimension=config.embed_dim, nucl_per_token=config.dim_divisible_by, remove_cls_token=False, ) # Correct transformer for layer in self.transformer: layer[0].fn[1] = BorzoiAttentionLayer( # type: ignore config.embed_dim, heads=config.num_attention_heads, dim_key=config.attention_dim_key, dim_value=config.embed_dim // config.num_attention_heads, dropout=0.05, pos_dropout=0.01, num_rel_pos_features=config.num_rel_pos_features, ) # Correct conv layer in downsample block self.unet_head.unet.downsample_blocks[0].conv_layers[0] = nn.Conv1d( in_channels=1920, out_channels=1536, kernel_size=3, stride=1, padding=1 ) # Correct bias in separable layers self.separable1.conv_layer[1].bias = None self.separable0.conv_layer[1].bias = None def forward(self, x: torch.Tensor) -> torch.Tensor: # Stem x = x.transpose(1, 2) x = self.stem(x) # Conv tower x_unet0 = self.res_tower(x) x_unet1 = self.unet1(x_unet0) x = self._max_pool(x_unet1) # Transformer tower x = x.permute(0, 2, 1) x = self.transformer(x) x = x.permute(0, 2, 1) # UNet conv x_unet1 = self.horizontal_conv1(x_unet1) x_unet0 = self.horizontal_conv0(x_unet0) # UNet upsampling and separable convolutions x = self.upsampling_unet1(x) x += x_unet1 x = self.separable1(x) x = self.upsampling_unet0(x) x += x_unet0 x = self.separable0(x) # Target length crop x = self.crop(x.permute(0, 2, 1)) x = x.permute(0, 2, 1) # Final convolution block x = self.final_joined_convs(x) x = self.unet_head(x) return x # Define custom attention layer for PyTorch model because Attention layer from the # imported model is not the same (the positional embeddings are not the same) def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor: """Prepends dimensions to match the required shape.""" for _ in range(num_dims - tensor.dim()): tensor = tensor.unsqueeze(0) return tensor def get_positional_features_central_mask_borzoi( positions: torch.Tensor, feature_size: int, seq_length: int ) -> torch.Tensor: """Positional features using a central mask (allow only central features).""" pow_rate = torch.exp(torch.log(torch.tensor(seq_length + 1.0)) / feature_size) center_widths = torch.pow(pow_rate, torch.arange(1, feature_size + 1).float()) - 1 center_widths = _prepend_dims(center_widths, positions.ndim) outputs = (center_widths > torch.abs(positions).unsqueeze(-1)).float() return outputs def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor: """ Compute positional embedding for Borzoi. Note that it is different than the one used in Enformer. """ distances = torch.arange(-seq_len + 1, seq_len) num_components = 2 if (feature_size % num_components) != 0: raise ValueError( f"feature size is not divisible by number of components ({num_components})" ) num_basis_per_class = feature_size // num_components embeddings = [] embeddings.append( get_positional_features_central_mask_borzoi( distances, num_basis_per_class, seq_len ) ) embeddings = torch.cat(embeddings, dim=-1) embeddings = torch.cat( (embeddings, torch.sign(distances).unsqueeze(-1) * embeddings), dim=-1 ) return embeddings def relative_shift(x: torch.Tensor) -> torch.Tensor: to_pad = torch.zeros_like(x[..., :1]) x = torch.cat((to_pad, x), dim=-1) _, h, t1, t2 = x.shape x = x.reshape(-1, h, t2, t1) # noqa: FKA100 x = x[:, :, 1:, :] x = x.reshape(-1, h, t1, t2 - 1) # noqa: FKA100 return x[..., : ((t2 + 1) // 2)] class BorzoiAttentionLayer(nn.Module): def __init__( # type: ignore self, dim, *, num_rel_pos_features, heads=8, dim_key=64, dim_value=64, dropout=0.0, pos_dropout=0.0, ) -> None: super().__init__() self.scale = dim_key**-0.5 self.heads = heads self.to_q = nn.Linear(dim, dim_key * heads, bias=False) self.to_k = nn.Linear(dim, dim_key * heads, bias=False) self.to_v = nn.Linear(dim, dim_value * heads, bias=False) self.to_out = nn.Linear(dim_value * heads, dim) nn.init.zeros_(self.to_out.weight) nn.init.zeros_(self.to_out.bias) self.num_rel_pos_features = num_rel_pos_features self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False) self.rel_content_bias = nn.Parameter( torch.randn(1, heads, 1, dim_key) # noqa: FKA100 ) self.rel_pos_bias = nn.Parameter( torch.randn(1, heads, 1, dim_key) # noqa: FKA100 ) # dropouts self.pos_dropout = nn.Dropout(pos_dropout) self.attn_dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: n, h = x.shape[-2], self.heads q = self.to_q(x) k = self.to_k(x) v = self.to_v(x) q, k, v = map( # noqa lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), # type: ignore (q, k, v), ) q = q * self.scale content_logits = einsum( "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k ) positions = get_positional_embed_borzoi(n, self.num_rel_pos_features) positions = self.pos_dropout(positions) rel_k = self.to_rel_k(positions) rel_k = rearrange(rel_k, "n (h d) -> h n d", h=h) rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k) rel_logits = relative_shift(rel_logits) logits = content_logits + rel_logits attn = logits.softmax(dim=-1) attn = self.attn_dropout(attn) out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out)