Upload SegmentEnformer
Browse files
- config.json +1 -1
- segment_enformer.py +337 -5
config.json
CHANGED
@@ -26,5 +26,5 @@
|
|
26 |
],
|
27 |
"model_type": "segment_enformer",
|
28 |
"torch_dtype": "float32",
|
29 |
-
"transformers_version": "4.
|
30 |
}
|
|
|
26 |
],
|
27 |
"model_type": "segment_enformer",
|
28 |
"torch_dtype": "float32",
|
29 |
+
"transformers_version": "4.48.0"
|
30 |
}
|
segment_enformer.py
CHANGED
@@ -1,13 +1,345 @@
|
|
1 |
-
from typing import Any, Dict, List
|
2 |
|
3 |
import torch
|
|
|
4 |
from einops import rearrange
|
5 |
from enformer_pytorch import Enformer
|
6 |
from transformers import PretrainedConfig, PreTrainedModel
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
FEATURES = [
|
13 |
"protein_coding_gene",
|
@@ -35,7 +367,7 @@ class SegmentEnformerConfig(PretrainedConfig):
|
|
35 |
features: List[str] = FEATURES,
|
36 |
embed_dim: int = 1536,
|
37 |
dim_divisible_by: int = 128,
|
38 |
-
**kwargs: Dict[str, Any]
|
39 |
) -> None:
|
40 |
self.features = features
|
41 |
self.embed_dim = embed_dim
|
|
|
1 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
2 |
|
3 |
import torch
|
4 |
+
import torch.nn as nn
|
5 |
from einops import rearrange
|
6 |
from enformer_pytorch import Enformer
|
7 |
from transformers import PretrainedConfig, PreTrainedModel
|
8 |
|
9 |
+
|
10 |
+
def get_activation_fn(activation_name: str) -> Callable:
    """
    Returns torch activation function

    Args:
        activation_name (str): Name of the activation function. Possible values are
            'swish', 'silu', 'relu', 'gelu', 'gelu-no-approx', 'sin'.
            'silu' is an alias of 'swish'. 'gelu-no-approx' is an alias of 'gelu':
            ``torch.nn.functional.gelu`` is called with its default arguments,
            i.e. without the tanh approximation.

    Raises:
        ValueError: If activation_name is not supported

    Returns:
        Callable: Activation function
    """
    # Dispatch table instead of an if/elif chain. The aliases cover every name
    # advertised by the docstrings of the modules in this file.
    activations: Dict[str, Callable] = {
        "swish": nn.functional.silu,
        "silu": nn.functional.silu,
        "relu": nn.functional.relu,
        "gelu": nn.functional.gelu,
        "gelu-no-approx": nn.functional.gelu,
        "sin": torch.sin,
    }
    try:
        return activations[activation_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also ended up in
        # the ValueError branch.
        raise ValueError(
            f"Unsupported activation function: {activation_name}"
        ) from None
|
34 |
+
|
35 |
+
|
36 |
+
class TorchDownSample1D(nn.Module):
    """
    Down-sampling block: a stack of 1D convolutions followed by average pooling.

    Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels fed to the first convolution.
            output_channels: number of channels produced by every convolution.
            activation_fn: name of the activation applied after each convolution;
                resolved through `get_activation_fn`, which raises ValueError
                for names it does not know.
            num_layers: number of convolution layers.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`; the following ones
        # map `output_channels` -> `output_channels`. Kernel 3 / stride 1 /
        # padding 1 keeps the sequence length unchanged.
        convolutions = []
        for layer_index in range(num_layers):
            in_channels = output_channels if layer_index else input_channels
            convolutions.append(
                nn.Conv1d(
                    in_channels,
                    output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
        self.conv_layers = nn.ModuleList(convolutions)

        # Window 2 / stride 2 pooling halves the length of the last dimension.
        self.avg_pool = nn.AvgPool1d(2, stride=2, padding=0)

        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the convolution stack, then pool.

        Returns:
            Tuple ``(pooled, hidden)`` where ``hidden`` is the activation before
            pooling and ``pooled`` is ``hidden`` after average pooling.
        """
        hidden = x
        for convolution in self.conv_layers:
            hidden = self.activation_fn(convolution(hidden))
        return self.avg_pool(hidden), hidden
|
82 |
+
|
83 |
+
|
84 |
+
class TorchUpSample1D(nn.Module):
    """
    Up-sampling block: a stack of 1D transposed convolutions followed by a
    2x interpolation.

    Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
        interpolation_method: str = "nearest",
    ):
        """
        Args:
            input_channels: number of channels fed to the first layer.
            output_channels: number of channels produced by every layer.
            activation_fn: name of the activation applied after each layer;
                resolved through `get_activation_fn`, which raises ValueError
                for names it does not know.
            num_layers: number of transposed-convolution layers.
            interpolation_method: forwarded unchanged as ``mode`` to
                ``torch.nn.functional.interpolate``.
                NOTE(review): the names inherited from the original
                implementation ("cubic", "lanczos3", "lanczos5") are presumably
                not valid torch modes for 3-D inputs - confirm before use.
        """
        super().__init__()

        # Only the first layer sees `input_channels`; kernel 3 / stride 1 /
        # padding 1 keeps the sequence length unchanged.
        transposed_convolutions = []
        for layer_index in range(num_layers):
            in_channels = output_channels if layer_index else input_channels
            transposed_convolutions.append(
                nn.ConvTranspose1d(
                    in_channels,
                    output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
        self.conv_transpose_layers = nn.ModuleList(transposed_convolutions)

        self.interpolation_mode = interpolation_method
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolutions, then upsample by a factor of 2."""
        for transposed_convolution in self.conv_transpose_layers:
            x = self.activation_fn(transposed_convolution(x))

        # `align_corners` is left unset (None) for "nearest" and set to False
        # for every other mode.
        if self.interpolation_mode == "nearest":
            align_corners = None
        else:
            align_corners = False

        return nn.functional.interpolate(
            x,
            scale_factor=2,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )
|
136 |
+
|
137 |
+
|
138 |
+
class TorchFinalConv1D(nn.Module):
    """
    Final convolution stack; the last convolution is not followed by an
    activation.

    Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels fed to the first convolution.
            output_channels: number of channels produced by every convolution.
            activation_fn: name of the activation applied between convolutions;
                resolved through `get_activation_fn`, which raises ValueError
                for names it does not know.
            num_layers: number of convolution layers.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`; kernel 3 / stride 1 /
        # padding 1 keeps the sequence length unchanged.
        convolutions = []
        for layer_index in range(num_layers):
            in_channels = output_channels if layer_index else input_channels
            convolutions.append(
                nn.Conv1d(
                    in_channels,
                    output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
        self.conv_layers = nn.ModuleList(convolutions)

        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every convolution, with the activation after all but the last."""
        last_index = len(self.conv_layers) - 1
        for index, convolution in enumerate(self.conv_layers):
            x = convolution(x)
            if index != last_index:
                x = self.activation_fn(x)
        return x
|
183 |
+
|
184 |
+
|
185 |
+
class TorchUNET1DSegmentationHead(nn.Module):
    """
    1D U-Net: down-sampling blocks, mirrored up-sampling blocks with additive
    skip connections, and a final convolution block.

    Torch adaptation of UNET1DSegmentationHead in
    trix.layers.heads.unet_segmentation_head.py
    """

    def __init__(
        self,
        num_classes: int,
        input_embed_dim: int,
        output_channels_list: Tuple[int, ...] = (64, 128, 256),
        activation_fn: str = "swish",
        num_conv_layers_per_block: int = 2,
        upsampling_interpolation_method: str = "nearest",
    ):
        """
        Args:
            num_classes: number of classes to segment; the final block outputs
                ``num_classes * 2`` channels.
            input_embed_dim: number of channels of the input tensor.
            output_channels_list: number of output channels at each level of
                the U-Net (one down-sampling and one up-sampling block per entry).
            activation_fn: activation name passed to every block.
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: interpolation method passed to
                every up-sampling block.
        """
        super().__init__()

        self.num_pooling_layers = len(output_channels_list)

        # Encoder: each level consumes the output channels of the previous one,
        # the first level consumes the input embedding.
        encoder_inputs = (input_embed_dim,) + output_channels_list[:-1]
        encoder_blocks = []
        for in_channels, out_channels in zip(encoder_inputs, output_channels_list):
            encoder_blocks.append(
                TorchDownSample1D(
                    input_channels=in_channels,
                    output_channels=out_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                )
            )
        self.downsample_blocks = nn.ModuleList(encoder_blocks)

        # Decoder: walks the levels in reverse. The first decoder block keeps the
        # deepest channel count; each following block reduces to the next
        # shallower level so its output can be added to that level's skip tensor.
        decoder_outputs = tuple(reversed(output_channels_list))
        decoder_inputs = (output_channels_list[-1],) + decoder_outputs[:-1]
        decoder_blocks = []
        for in_channels, out_channels in zip(decoder_inputs, decoder_outputs):
            decoder_blocks.append(
                TorchUpSample1D(
                    input_channels=in_channels,
                    output_channels=out_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                    interpolation_method=upsampling_interpolation_method,
                )
            )
        self.upsample_blocks = nn.ModuleList(decoder_blocks)

        self.final_block = TorchFinalConv1D(
            input_channels=output_channels_list[0],
            output_channels=num_classes * 2,
            activation_fn=activation_fn,
            num_layers=num_conv_layers_per_block,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the U-Net on ``x``.

        Raises:
            ValueError: if the size of the last dimension of ``x`` is not
                divisible by ``2 ** num_pooling_layers``.
        """
        if x.shape[-1] % (2**self.num_pooling_layers) != 0:
            raise ValueError(
                "Input length must be divisible by 2 to the power of "
                "the number of pooling layers."
            )

        # Keep the pre-pooling activation of every level for the skip connections.
        skips = []
        for encoder_block in self.downsample_blocks:
            x, skip = encoder_block(x)
            skips.append(skip)

        # Deepest level first: add each decoder output to the matching skip.
        for decoder_block, skip in zip(self.upsample_blocks, skips[::-1]):
            x = decoder_block(x) + skip

        return self.final_block(x)
|
275 |
+
|
276 |
+
|
277 |
+
class TorchUNetHead(nn.Module):
    """
    Segmentation head: a 1D U-Net followed by a linear layer producing
    per-nucleotide, per-feature class logits.

    Torch adaptation of UNetHead in
    genomics_research/segmentnt/layers/segmentation_head.py
    """

    def __init__(
        self,
        features: List[str],
        num_classes: int = 2,
        embed_dimension: int = 1024,
        nucl_per_token: int = 6,
        num_layers: int = 2,
        remove_cls_token: bool = True,
    ):
        """
        Args:
            features (List[str]): List of features names; only its length is used.
            num_classes (int): Number of classes.
            embed_dimension (int): Embedding dimension.
            nucl_per_token (int): Number of nucleotides per token.
            num_layers (int): Number of U-Net levels.
            remove_cls_token (bool): Whether to drop index 0 along dimension 1
                of the input before running the U-Net.
        """
        super().__init__()
        self._num_features = len(features)
        self._num_classes = num_classes
        self.nucl_per_token = nucl_per_token
        self.remove_cls_token = remove_cls_token

        # Channel count doubles at every level. The U-Net emits
        # `num_classes * 2` channels, so asking for `embed_dimension // 2`
        # classes yields `embed_dimension` output channels (for an even
        # `embed_dimension`), matching the input size of `self.fc`.
        level_channels = tuple(
            embed_dimension * (2**level) for level in range(num_layers)
        )
        self.unet = TorchUNET1DSegmentationHead(
            num_classes=embed_dimension // 2,
            output_channels_list=level_channels,
            input_embed_dim=embed_dimension,
        )

        logits_per_token = (
            self.nucl_per_token * self._num_classes * self._num_features
        )
        self.fc = nn.Linear(embed_dimension, logits_per_token)

    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute segmentation logits.

        Args:
            x: input tensor handed to the U-Net.
            sequence_mask: accepted for interface compatibility; not used here.

        Returns:
            Dict with key "logits" holding a tensor viewed as
            ``(batch, seq_len * nucl_per_token, num_features, num_classes)``.
        """
        if self.remove_cls_token:
            x = x[:, 1:]

        # U-Net output is channel-first; move channels last for the linear layer.
        tokens = nn.functional.silu(self.unet(x)).transpose(2, 1)
        batch_size, seq_len, _ = tokens.shape

        logits = self.fc(tokens).view(  # noqa
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )
        return {"logits": logits}
|
342 |
+
|
343 |
|
344 |
FEATURES = [
|
345 |
"protein_coding_gene",
|
|
|
367 |
features: List[str] = FEATURES,
|
368 |
embed_dim: int = 1536,
|
369 |
dim_divisible_by: int = 128,
|
370 |
+
**kwargs: Dict[str, Any],
|
371 |
) -> None:
|
372 |
self.features = features
|
373 |
self.embed_dim = embed_dim
|