Yanisadel commited on
Commit
669aaf5
·
verified ·
1 Parent(s): 95d75ca

Upload SegmentBorzoi

Browse files
Files changed (1) hide show
  1. segment_borzoi.py +20 -11
segment_borzoi.py CHANGED
@@ -7,7 +7,9 @@ from einops import rearrange
7
  from torch import einsum
8
  from transformers import PretrainedConfig, PreTrainedModel
9
 
10
- from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead
 
 
11
 
12
  FEATURES = [
13
  "protein_coding_gene",
@@ -91,7 +93,7 @@ class SegmentBorzoi(PreTrainedModel):
91
 
92
  # Correct transformer
93
  for layer in self.transformer:
94
- layer[0].fn[1] = BorzoiAttentionLayer(
95
  config.embed_dim,
96
  heads=config.num_attention_heads,
97
  dim_key=config.attention_dim_key,
@@ -105,7 +107,7 @@ class SegmentBorzoi(PreTrainedModel):
105
  self.separable1.conv_layer[1].bias = None
106
  self.separable0.conv_layer[1].bias = None
107
 
108
- def forward(self, x):
109
  # Stem
110
  x = x.transpose(1, 2)
111
  x = self.stem(x)
@@ -199,14 +201,14 @@ def relative_shift(x: torch.Tensor) -> torch.Tensor:
199
  to_pad = torch.zeros_like(x[..., :1])
200
  x = torch.cat((to_pad, x), dim=-1)
201
  _, h, t1, t2 = x.shape
202
- x = x.reshape(-1, h, t2, t1)
203
  x = x[:, :, 1:, :]
204
- x = x.reshape(-1, h, t1, t2 - 1)
205
  return x[..., : ((t2 + 1) // 2)]
206
 
207
 
208
  class BorzoiAttentionLayer(nn.Module):
209
- def __init__(
210
  self,
211
  dim,
212
  *,
@@ -216,7 +218,7 @@ class BorzoiAttentionLayer(nn.Module):
216
  dim_value=64,
217
  dropout=0.0,
218
  pos_dropout=0.0,
219
- ):
220
  super().__init__()
221
  self.scale = dim_key**-0.5
222
  self.heads = heads
@@ -232,22 +234,29 @@ class BorzoiAttentionLayer(nn.Module):
232
  self.num_rel_pos_features = num_rel_pos_features
233
 
234
  self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
235
- self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
236
- self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
 
 
 
 
237
 
238
  # dropouts
239
 
240
  self.pos_dropout = nn.Dropout(pos_dropout)
241
  self.attn_dropout = nn.Dropout(dropout)
242
 
243
- def forward(self, x):
244
  n, h = x.shape[-2], self.heads
245
 
246
  q = self.to_q(x)
247
  k = self.to_k(x)
248
  v = self.to_v(x)
249
 
250
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
 
 
251
 
252
  q = q * self.scale
253
 
 
7
  from torch import einsum
8
  from transformers import PretrainedConfig, PreTrainedModel
9
 
10
+ from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
11
+ TorchUNetHead,
12
+ )
13
 
14
  FEATURES = [
15
  "protein_coding_gene",
 
93
 
94
  # Correct transformer
95
  for layer in self.transformer:
96
+ layer[0].fn[1] = BorzoiAttentionLayer( # type: ignore
97
  config.embed_dim,
98
  heads=config.num_attention_heads,
99
  dim_key=config.attention_dim_key,
 
107
  self.separable1.conv_layer[1].bias = None
108
  self.separable0.conv_layer[1].bias = None
109
 
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
  # Stem
112
  x = x.transpose(1, 2)
113
  x = self.stem(x)
 
201
  to_pad = torch.zeros_like(x[..., :1])
202
  x = torch.cat((to_pad, x), dim=-1)
203
  _, h, t1, t2 = x.shape
204
+ x = x.reshape(-1, h, t2, t1) # noqa: FKA100
205
  x = x[:, :, 1:, :]
206
+ x = x.reshape(-1, h, t1, t2 - 1) # noqa: FKA100
207
  return x[..., : ((t2 + 1) // 2)]
208
 
209
 
210
  class BorzoiAttentionLayer(nn.Module):
211
+ def __init__( # type: ignore
212
  self,
213
  dim,
214
  *,
 
218
  dim_value=64,
219
  dropout=0.0,
220
  pos_dropout=0.0,
221
+ ) -> None:
222
  super().__init__()
223
  self.scale = dim_key**-0.5
224
  self.heads = heads
 
234
  self.num_rel_pos_features = num_rel_pos_features
235
 
236
  self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
237
+ self.rel_content_bias = nn.Parameter(
238
+ torch.randn(1, heads, 1, dim_key) # noqa: FKA100
239
+ )
240
+ self.rel_pos_bias = nn.Parameter(
241
+ torch.randn(1, heads, 1, dim_key) # noqa: FKA100
242
+ )
243
 
244
  # dropouts
245
 
246
  self.pos_dropout = nn.Dropout(pos_dropout)
247
  self.attn_dropout = nn.Dropout(dropout)
248
 
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
  n, h = x.shape[-2], self.heads
251
 
252
  q = self.to_q(x)
253
  k = self.to_k(x)
254
  v = self.to_v(x)
255
 
256
+ q, k, v = map( # noqa
257
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), # type: ignore
258
+ (q, k, v),
259
+ )
260
 
261
  q = q * self.scale
262