Upload SegmentBorzoi
segment_borzoi.py  (CHANGED: +20 -11)
@@ -7,7 +7,9 @@ from einops import rearrange
 from torch import einsum
 from transformers import PretrainedConfig, PreTrainedModel
 
-from genomics_research.segmentnt.layers.
+from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
+    TorchUNetHead,
+)
 
 FEATURES = [
     "protein_coding_gene",
@@ -91,7 +93,7 @@ class SegmentBorzoi(PreTrainedModel):
 
         # Correct transformer
         for layer in self.transformer:
-            layer[0].fn[1] = BorzoiAttentionLayer(
+            layer[0].fn[1] = BorzoiAttentionLayer(  # type: ignore
                 config.embed_dim,
                 heads=config.num_attention_heads,
                 dim_key=config.attention_dim_key,
@@ -105,7 +107,7 @@ class SegmentBorzoi(PreTrainedModel):
         self.separable1.conv_layer[1].bias = None
         self.separable0.conv_layer[1].bias = None
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Stem
         x = x.transpose(1, 2)
         x = self.stem(x)
@@ -199,14 +201,14 @@ def relative_shift(x: torch.Tensor) -> torch.Tensor:
     to_pad = torch.zeros_like(x[..., :1])
     x = torch.cat((to_pad, x), dim=-1)
     _, h, t1, t2 = x.shape
-    x = x.reshape(-1, h, t2, t1)
+    x = x.reshape(-1, h, t2, t1)  # noqa: FKA100
     x = x[:, :, 1:, :]
-    x = x.reshape(-1, h, t1, t2 - 1)
+    x = x.reshape(-1, h, t1, t2 - 1)  # noqa: FKA100
     return x[..., : ((t2 + 1) // 2)]
 
 
 class BorzoiAttentionLayer(nn.Module):
-    def __init__(
+    def __init__(  # type: ignore
         self,
         dim,
         *,
@@ -216,7 +218,7 @@ class BorzoiAttentionLayer(nn.Module):
         dim_value=64,
         dropout=0.0,
         pos_dropout=0.0,
-    ):
+    ) -> None:
         super().__init__()
         self.scale = dim_key**-0.5
         self.heads = heads
@@ -232,22 +234,29 @@ class BorzoiAttentionLayer(nn.Module):
         self.num_rel_pos_features = num_rel_pos_features
 
         self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
-        self.rel_content_bias = nn.Parameter(
-
+        self.rel_content_bias = nn.Parameter(
+            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
+        )
+        self.rel_pos_bias = nn.Parameter(
+            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
+        )
 
         # dropouts
 
         self.pos_dropout = nn.Dropout(pos_dropout)
         self.attn_dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         n, h = x.shape[-2], self.heads
 
         q = self.to_q(x)
         k = self.to_k(x)
         v = self.to_v(x)
 
-        q, k, v = map(
+        q, k, v = map(  # noqa
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h),  # type: ignore
+            (q, k, v),
+        )
 
         q = q * self.scale
 
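The relative_shift hunk only appends # noqa: FKA100 comments, so its behaviour is unchanged. As a quick sanity check, here is a minimal standalone sketch of the same reshape-based relative-shift trick on a toy tensor; the shapes (batch 2, 3 heads, query length 4, 2*4 - 1 = 7 relative positions) are illustrative and not taken from the model config.

import torch

# Toy tensor shaped (batch, heads, t1, 2*t1 - 1), as produced by relative-position logits.
x = torch.arange(2 * 3 * 4 * 7, dtype=torch.float32).reshape(2, 3, 4, 7)

# Same steps as relative_shift above: pad, fold, drop the first column, unfold, crop.
to_pad = torch.zeros_like(x[..., :1])
x = torch.cat((to_pad, x), dim=-1)
_, h, t1, t2 = x.shape
x = x.reshape(-1, h, t2, t1)
x = x[:, :, 1:, :]
x = x.reshape(-1, h, t1, t2 - 1)
out = x[..., : ((t2 + 1) // 2)]
print(out.shape)  # torch.Size([2, 3, 4, 4]) -- one aligned relative logit per query/key pair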
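The reformatted map(...) call in BorzoiAttentionLayer.forward splits the flat projection dimension into heads with einops. Below is a minimal sketch of what that rearrange does, using illustrative sizes (batch 2, sequence 16, 8 heads of 64 dims) rather than the actual config values.

import torch
from einops import rearrange

b, n, heads, dim_key = 2, 16, 8, 64
q = torch.randn(b, n, heads * dim_key)
k = torch.randn(b, n, heads * dim_key)
v = torch.randn(b, n, heads * dim_key)

# "b n (h d) -> b h n d": split the last axis into (heads, dim_key) and move heads forward.
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=heads), (q, k, v))
print(q.shape, k.shape, v.shape)  # each torch.Size([2, 8, 16, 64])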