Yanisadel commited on
Commit
f7aa1ae
·
verified ·
1 Parent(s): d33edb8

Upload SegmentEnformer

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. segment_enformer.py +337 -5
config.json CHANGED
@@ -26,5 +26,5 @@
26
  ],
27
  "model_type": "segment_enformer",
28
  "torch_dtype": "float32",
29
- "transformers_version": "4.41.1"
30
  }
 
26
  ],
27
  "model_type": "segment_enformer",
28
  "torch_dtype": "float32",
29
+ "transformers_version": "4.48.0"
30
  }
segment_enformer.py CHANGED
@@ -1,13 +1,345 @@
1
- from typing import Any, Dict, List
2
 
3
  import torch
 
4
  from einops import rearrange
5
  from enformer_pytorch import Enformer
6
  from transformers import PretrainedConfig, PreTrainedModel
7
 
8
- from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
9
- TorchUNetHead,
10
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  FEATURES = [
13
  "protein_coding_gene",
@@ -35,7 +367,7 @@ class SegmentEnformerConfig(PretrainedConfig):
35
  features: List[str] = FEATURES,
36
  embed_dim: int = 1536,
37
  dim_divisible_by: int = 128,
38
- **kwargs: Dict[str, Any]
39
  ) -> None:
40
  self.features = features
41
  self.embed_dim = embed_dim
 
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple
2
 
3
  import torch
4
+ import torch.nn as nn
5
  from einops import rearrange
6
  from enformer_pytorch import Enformer
7
  from transformers import PretrainedConfig, PreTrainedModel
8
 
9
+
10
+ def get_activation_fn(activation_name: str) -> Callable:
11
+ """
12
+ Returns torch activation function
13
+
14
+ Args:
15
+ activation_name (str): Name of the activation function. Possible values are
16
+ 'swish', 'relu', 'gelu', 'sin'
17
+
18
+ Raises:
19
+ ValueError: If activation_name is not supported
20
+
21
+ Returns:
22
+ Callable: Activation function
23
+ """
24
+ if activation_name == "swish":
25
+ return nn.functional.silu # type: ignore
26
+ elif activation_name == "relu":
27
+ return nn.functional.relu # type: ignore
28
+ elif activation_name == "gelu":
29
+ return nn.functional.gelu # type: ignore
30
+ elif activation_name == "sin":
31
+ return torch.sin # type: ignore
32
+ else:
33
+ raise ValueError(f"Unsupported activation function: {activation_name}")
34
+
35
+
36
+ class TorchDownSample1D(nn.Module):
37
+ """
38
+ Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ input_channels: int,
44
+ output_channels: int,
45
+ activation_fn: str = "swish",
46
+ num_layers: int = 2,
47
+ ):
48
+ """
49
+ Args:
50
+ input_channels: number of input channels
51
+ output_channels: number of output channels.
52
+ activation_fn: name of the activation function to use.
53
+ Should be one of "gelu",
54
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
55
+ num_layers: number of convolution layers.
56
+ """
57
+ super().__init__()
58
+
59
+ self.conv_layers = nn.ModuleList(
60
+ [
61
+ nn.Conv1d(
62
+ in_channels=input_channels if i == 0 else output_channels,
63
+ out_channels=output_channels,
64
+ kernel_size=3,
65
+ stride=1,
66
+ padding=1,
67
+ )
68
+ for i in range(num_layers)
69
+ ]
70
+ )
71
+
72
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
73
+
74
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
75
+
76
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ for conv_layer in self.conv_layers:
78
+ x = self.activation_fn(conv_layer(x))
79
+ hidden = x
80
+ x = self.avg_pool(hidden)
81
+ return x, hidden
82
+
83
+
84
+ class TorchUpSample1D(nn.Module):
85
+ """
86
+ Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ input_channels: int,
92
+ output_channels: int,
93
+ activation_fn: str = "swish",
94
+ num_layers: int = 2,
95
+ interpolation_method: str = "nearest",
96
+ ):
97
+ """
98
+ Args:
99
+ input_channels: number of input channels.
100
+ output_channels: number of output channels.
101
+ activation_fn: name of the activation function to use.
102
+ Should be one of "gelu",
103
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
104
+ interpolation_method: Method to be used for upsampling interpolation.
105
+ Should be one of "nearest", "linear", "cubic", "lanczos3", "lanczos5".
106
+ num_layers: number of convolution layers.
107
+ """
108
+ super().__init__()
109
+
110
+ self.conv_transpose_layers = nn.ModuleList(
111
+ [
112
+ nn.ConvTranspose1d(
113
+ in_channels=input_channels if i == 0 else output_channels,
114
+ out_channels=output_channels,
115
+ kernel_size=3,
116
+ stride=1,
117
+ padding=1,
118
+ )
119
+ for i in range(num_layers)
120
+ ]
121
+ )
122
+
123
+ self.interpolation_mode = interpolation_method
124
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ for conv_layer in self.conv_transpose_layers:
128
+ x = self.activation_fn(conv_layer(x))
129
+ x = nn.functional.interpolate(
130
+ x,
131
+ scale_factor=2,
132
+ mode=self.interpolation_mode,
133
+ align_corners=False if self.interpolation_mode != "nearest" else None,
134
+ )
135
+ return x
136
+
137
+
138
+ class TorchFinalConv1D(nn.Module):
139
+ """
140
+ Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py
141
+ """
142
+
143
+ def __init__(
144
+ self,
145
+ input_channels: int,
146
+ output_channels: int,
147
+ activation_fn: str = "swish",
148
+ num_layers: int = 2,
149
+ ):
150
+ """
151
+ Args:
152
+ input_channels: number of input channels
153
+ output_channels: number of output channels.
154
+ activation_fn: name of the activation function to use.
155
+ Should be one of "gelu",
156
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
157
+ num_layers: number of convolution layers.
158
+ name: module name.
159
+ """
160
+ super().__init__()
161
+
162
+ self.conv_layers = nn.ModuleList(
163
+ [
164
+ nn.Conv1d(
165
+ in_channels=input_channels if i == 0 else output_channels,
166
+ out_channels=output_channels,
167
+ kernel_size=3,
168
+ stride=1,
169
+ padding=1,
170
+ )
171
+ for i in range(num_layers)
172
+ ]
173
+ )
174
+
175
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
176
+
177
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
178
+ for i, conv_layer in enumerate(self.conv_layers):
179
+ x = conv_layer(x)
180
+ if i < len(self.conv_layers) - 1:
181
+ x = self.activation_fn(x)
182
+ return x
183
+
184
+
185
+ class TorchUNET1DSegmentationHead(nn.Module):
186
+ """
187
+ Torch adaptation of UNET1DSegmentationHead in
188
+ trix.layers.heads.unet_segmentation_head.py
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ num_classes: int,
194
+ input_embed_dim: int,
195
+ output_channels_list: Tuple[int, ...] = (64, 128, 256),
196
+ activation_fn: str = "swish",
197
+ num_conv_layers_per_block: int = 2,
198
+ upsampling_interpolation_method: str = "nearest",
199
+ ):
200
+ """
201
+ Args:
202
+ num_classes: number of classes to segment
203
+ output_channels_list: list of the number of output channel at each level of
204
+ the UNET
205
+ activation_fn: name of the activation function to use.
206
+ Should be one of "gelu",
207
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
208
+ num_conv_layers_per_block: number of convolution layers per block.
209
+ upsampling_interpolation_method: Method to be used for
210
+ interpolation in upsampling blocks. Should be one of "nearest",
211
+ "linear", "cubic", "lanczos3", "lanczos5".
212
+ """
213
+ super().__init__()
214
+
215
+ input_channels_list = (input_embed_dim,) + output_channels_list[:-1]
216
+
217
+ self.num_pooling_layers = len(output_channels_list)
218
+ self.downsample_blocks = nn.ModuleList(
219
+ [
220
+ TorchDownSample1D(
221
+ input_channels=input_channels,
222
+ output_channels=output_channels,
223
+ activation_fn=activation_fn,
224
+ num_layers=num_conv_layers_per_block,
225
+ )
226
+ for input_channels, output_channels in zip(
227
+ input_channels_list, output_channels_list
228
+ )
229
+ ]
230
+ )
231
+
232
+ input_channels_list = (output_channels_list[-1],) + tuple(
233
+ list(reversed(output_channels_list))[:-1]
234
+ )
235
+
236
+ self.upsample_blocks = nn.ModuleList(
237
+ [
238
+ TorchUpSample1D(
239
+ input_channels=input_channels,
240
+ output_channels=output_channels,
241
+ activation_fn=activation_fn,
242
+ num_layers=num_conv_layers_per_block,
243
+ interpolation_method=upsampling_interpolation_method,
244
+ )
245
+ for input_channels, output_channels in zip(
246
+ input_channels_list, reversed(output_channels_list)
247
+ )
248
+ ]
249
+ )
250
+
251
+ self.final_block = TorchFinalConv1D(
252
+ activation_fn=activation_fn,
253
+ input_channels=output_channels_list[0],
254
+ output_channels=num_classes * 2,
255
+ num_layers=num_conv_layers_per_block,
256
+ )
257
+
258
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
259
+ if x.shape[-1] % 2**self.num_pooling_layers:
260
+ raise ValueError(
261
+ "Input length must be divisible by 2 to the power of "
262
+ "the number of pooling layers."
263
+ )
264
+
265
+ hiddens = []
266
+ for downsample_block in self.downsample_blocks:
267
+ x, hidden = downsample_block(x)
268
+ hiddens.append(hidden)
269
+
270
+ for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
271
+ x = upsample_block(x) + hidden
272
+
273
+ x = self.final_block(x)
274
+ return x
275
+
276
+
277
+ class TorchUNetHead(nn.Module):
278
+ """
279
+ Torch adaptation of UNetHead in
280
+ genomics_research/segmentnt/layers/segmentation_head.py
281
+ """
282
+
283
+ def __init__(
284
+ self,
285
+ features: List[str],
286
+ num_classes: int = 2,
287
+ embed_dimension: int = 1024,
288
+ nucl_per_token: int = 6,
289
+ num_layers: int = 2,
290
+ remove_cls_token: bool = True,
291
+ ):
292
+ """
293
+ Args:
294
+ features (List[str]): List of features names.
295
+ num_classes (int): Number of classes.
296
+ embed_dimension (int): Embedding dimension.
297
+ nucl_per_token (int): Number of nucleotides per token.
298
+ num_layers (int): Number of layers.
299
+ remove_cls_token (bool): Whether to remove the CLS token.
300
+ name: Name the layer. Defaults to None.
301
+ """
302
+ super().__init__()
303
+ self._num_features = len(features)
304
+ self._num_classes = num_classes
305
+ self.nucl_per_token = nucl_per_token
306
+ self.remove_cls_token = remove_cls_token
307
+
308
+ self.unet = TorchUNET1DSegmentationHead(
309
+ num_classes=embed_dimension // 2,
310
+ output_channels_list=tuple(
311
+ embed_dimension * (2**i) for i in range(num_layers)
312
+ ),
313
+ input_embed_dim=embed_dimension,
314
+ )
315
+
316
+ self.fc = nn.Linear(
317
+ embed_dimension,
318
+ self.nucl_per_token * self._num_classes * self._num_features,
319
+ )
320
+
321
+ def forward(
322
+ self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
323
+ ) -> Dict[str, torch.Tensor]:
324
+ if self.remove_cls_token:
325
+ x = x[:, 1:]
326
+
327
+ x = self.unet(x)
328
+ x = nn.functional.silu(x)
329
+
330
+ x = x.transpose(2, 1)
331
+ logits = self.fc(x)
332
+
333
+ batch_size, seq_len, _ = x.shape
334
+ logits = logits.view( # noqa
335
+ batch_size,
336
+ seq_len * self.nucl_per_token,
337
+ self._num_features,
338
+ self._num_classes,
339
+ )
340
+
341
+ return {"logits": logits}
342
+
343
 
344
  FEATURES = [
345
  "protein_coding_gene",
 
367
  features: List[str] = FEATURES,
368
  embed_dim: int = 1536,
369
  dim_divisible_by: int = 128,
370
+ **kwargs: Dict[str, Any],
371
  ) -> None:
372
  self.features = features
373
  self.embed_dim = embed_dim