Yanisadel commited on
Commit
e0c22b2
·
verified ·
1 Parent(s): e3d58de

Upload SegmentBorzoi

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. segment_borzoi.py +335 -4
config.json CHANGED
@@ -29,5 +29,5 @@
29
  "num_attention_heads": 8,
30
  "num_rel_pos_features": 32,
31
  "torch_dtype": "float32",
32
- "transformers_version": "4.41.1"
33
  }
 
29
  "num_attention_heads": 8,
30
  "num_rel_pos_features": 32,
31
  "torch_dtype": "float32",
32
+ "transformers_version": "4.48.0"
33
  }
segment_borzoi.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List
2
 
3
  import borzoi_pytorch
4
  import torch
@@ -7,9 +7,340 @@ from einops import rearrange
7
  from torch import einsum
8
  from transformers import PretrainedConfig, PreTrainedModel
9
 
10
- from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
11
- TorchUNetHead,
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  FEATURES = [
15
  "protein_coding_gene",
 
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple
2
 
3
  import borzoi_pytorch
4
  import torch
 
7
  from torch import einsum
8
  from transformers import PretrainedConfig, PreTrainedModel
9
 
10
+
11
+ def get_activation_fn(activation_name: str) -> Callable:
12
+ """
13
+ Returns torch activation function
14
+
15
+ Args:
16
+ activation_name (str): Name of the activation function. Possible values are
17
+ 'swish', 'relu', 'gelu', 'sin'
18
+
19
+ Raises:
20
+ ValueError: If activation_name is not supported
21
+
22
+ Returns:
23
+ Callable: Activation function
24
+ """
25
+ if activation_name == "swish":
26
+ return nn.functional.silu # type: ignore
27
+ elif activation_name == "relu":
28
+ return nn.functional.relu # type: ignore
29
+ elif activation_name == "gelu":
30
+ return nn.functional.gelu # type: ignore
31
+ elif activation_name == "sin":
32
+ return torch.sin # type: ignore
33
+ else:
34
+ raise ValueError(f"Unsupported activation function: {activation_name}")
35
+
36
+
37
+ class TorchDownSample1D(nn.Module):
38
+ """
39
+ Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ input_channels: int,
45
+ output_channels: int,
46
+ activation_fn: str = "swish",
47
+ num_layers: int = 2,
48
+ ):
49
+ """
50
+ Args:
51
+ input_channels: number of input channels
52
+ output_channels: number of output channels.
53
+ activation_fn: name of the activation function to use.
54
+ Should be one of "gelu",
55
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
56
+ num_layers: number of convolution layers.
57
+ """
58
+ super().__init__()
59
+
60
+ self.conv_layers = nn.ModuleList(
61
+ [
62
+ nn.Conv1d(
63
+ in_channels=input_channels if i == 0 else output_channels,
64
+ out_channels=output_channels,
65
+ kernel_size=3,
66
+ stride=1,
67
+ padding=1,
68
+ )
69
+ for i in range(num_layers)
70
+ ]
71
+ )
72
+
73
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
74
+
75
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
76
+
77
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
78
+ for conv_layer in self.conv_layers:
79
+ x = self.activation_fn(conv_layer(x))
80
+ hidden = x
81
+ x = self.avg_pool(hidden)
82
+ return x, hidden
83
+
84
+
85
+ class TorchUpSample1D(nn.Module):
86
+ """
87
+ Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ input_channels: int,
93
+ output_channels: int,
94
+ activation_fn: str = "swish",
95
+ num_layers: int = 2,
96
+ interpolation_method: str = "nearest",
97
+ ):
98
+ """
99
+ Args:
100
+ input_channels: number of input channels.
101
+ output_channels: number of output channels.
102
+ activation_fn: name of the activation function to use.
103
+ Should be one of "gelu",
104
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
105
+ interpolation_method: Method to be used for upsampling interpolation.
106
+ Should be one of "nearest", "linear", "cubic", "lanczos3", "lanczos5".
107
+ num_layers: number of convolution layers.
108
+ """
109
+ super().__init__()
110
+
111
+ self.conv_transpose_layers = nn.ModuleList(
112
+ [
113
+ nn.ConvTranspose1d(
114
+ in_channels=input_channels if i == 0 else output_channels,
115
+ out_channels=output_channels,
116
+ kernel_size=3,
117
+ stride=1,
118
+ padding=1,
119
+ )
120
+ for i in range(num_layers)
121
+ ]
122
+ )
123
+
124
+ self.interpolation_mode = interpolation_method
125
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ for conv_layer in self.conv_transpose_layers:
129
+ x = self.activation_fn(conv_layer(x))
130
+ x = nn.functional.interpolate(
131
+ x,
132
+ scale_factor=2,
133
+ mode=self.interpolation_mode,
134
+ align_corners=False if self.interpolation_mode != "nearest" else None,
135
+ )
136
+ return x
137
+
138
+
139
+ class TorchFinalConv1D(nn.Module):
140
+ """
141
+ Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py
142
+ """
143
+
144
+ def __init__(
145
+ self,
146
+ input_channels: int,
147
+ output_channels: int,
148
+ activation_fn: str = "swish",
149
+ num_layers: int = 2,
150
+ ):
151
+ """
152
+ Args:
153
+ input_channels: number of input channels
154
+ output_channels: number of output channels.
155
+ activation_fn: name of the activation function to use.
156
+ Should be one of "gelu",
157
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
158
+ num_layers: number of convolution layers.
159
+ name: module name.
160
+ """
161
+ super().__init__()
162
+
163
+ self.conv_layers = nn.ModuleList(
164
+ [
165
+ nn.Conv1d(
166
+ in_channels=input_channels if i == 0 else output_channels,
167
+ out_channels=output_channels,
168
+ kernel_size=3,
169
+ stride=1,
170
+ padding=1,
171
+ )
172
+ for i in range(num_layers)
173
+ ]
174
+ )
175
+
176
+ self.activation_fn: Callable = get_activation_fn(activation_fn)
177
+
178
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
179
+ for i, conv_layer in enumerate(self.conv_layers):
180
+ x = conv_layer(x)
181
+ if i < len(self.conv_layers) - 1:
182
+ x = self.activation_fn(x)
183
+ return x
184
+
185
+
186
+ class TorchUNET1DSegmentationHead(nn.Module):
187
+ """
188
+ Torch adaptation of UNET1DSegmentationHead in
189
+ trix.layers.heads.unet_segmentation_head.py
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ num_classes: int,
195
+ input_embed_dim: int,
196
+ output_channels_list: Tuple[int, ...] = (64, 128, 256),
197
+ activation_fn: str = "swish",
198
+ num_conv_layers_per_block: int = 2,
199
+ upsampling_interpolation_method: str = "nearest",
200
+ ):
201
+ """
202
+ Args:
203
+ num_classes: number of classes to segment
204
+ output_channels_list: list of the number of output channel at each level of
205
+ the UNET
206
+ activation_fn: name of the activation function to use.
207
+ Should be one of "gelu",
208
+ "gelu-no-approx", "relu", "swish", "silu", "sin".
209
+ num_conv_layers_per_block: number of convolution layers per block.
210
+ upsampling_interpolation_method: Method to be used for
211
+ interpolation in upsampling blocks. Should be one of "nearest",
212
+ "linear", "cubic", "lanczos3", "lanczos5".
213
+ """
214
+ super().__init__()
215
+
216
+ input_channels_list = (input_embed_dim,) + output_channels_list[:-1]
217
+
218
+ self.num_pooling_layers = len(output_channels_list)
219
+ self.downsample_blocks = nn.ModuleList(
220
+ [
221
+ TorchDownSample1D(
222
+ input_channels=input_channels,
223
+ output_channels=output_channels,
224
+ activation_fn=activation_fn,
225
+ num_layers=num_conv_layers_per_block,
226
+ )
227
+ for input_channels, output_channels in zip(
228
+ input_channels_list, output_channels_list
229
+ )
230
+ ]
231
+ )
232
+
233
+ input_channels_list = (output_channels_list[-1],) + tuple(
234
+ list(reversed(output_channels_list))[:-1]
235
+ )
236
+
237
+ self.upsample_blocks = nn.ModuleList(
238
+ [
239
+ TorchUpSample1D(
240
+ input_channels=input_channels,
241
+ output_channels=output_channels,
242
+ activation_fn=activation_fn,
243
+ num_layers=num_conv_layers_per_block,
244
+ interpolation_method=upsampling_interpolation_method,
245
+ )
246
+ for input_channels, output_channels in zip(
247
+ input_channels_list, reversed(output_channels_list)
248
+ )
249
+ ]
250
+ )
251
+
252
+ self.final_block = TorchFinalConv1D(
253
+ activation_fn=activation_fn,
254
+ input_channels=output_channels_list[0],
255
+ output_channels=num_classes * 2,
256
+ num_layers=num_conv_layers_per_block,
257
+ )
258
+
259
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
260
+ if x.shape[-1] % 2**self.num_pooling_layers:
261
+ raise ValueError(
262
+ "Input length must be divisible by 2 to the power of "
263
+ "the number of pooling layers."
264
+ )
265
+
266
+ hiddens = []
267
+ for downsample_block in self.downsample_blocks:
268
+ x, hidden = downsample_block(x)
269
+ hiddens.append(hidden)
270
+
271
+ for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
272
+ x = upsample_block(x) + hidden
273
+
274
+ x = self.final_block(x)
275
+ return x
276
+
277
+
278
+ class TorchUNetHead(nn.Module):
279
+ """
280
+ Torch adaptation of UNetHead in
281
+ genomics_research/segmentnt/layers/segmentation_head.py
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ features: List[str],
287
+ num_classes: int = 2,
288
+ embed_dimension: int = 1024,
289
+ nucl_per_token: int = 6,
290
+ num_layers: int = 2,
291
+ remove_cls_token: bool = True,
292
+ ):
293
+ """
294
+ Args:
295
+ features (List[str]): List of features names.
296
+ num_classes (int): Number of classes.
297
+ embed_dimension (int): Embedding dimension.
298
+ nucl_per_token (int): Number of nucleotides per token.
299
+ num_layers (int): Number of layers.
300
+ remove_cls_token (bool): Whether to remove the CLS token.
301
+ name: Name the layer. Defaults to None.
302
+ """
303
+ super().__init__()
304
+ self._num_features = len(features)
305
+ self._num_classes = num_classes
306
+ self.nucl_per_token = nucl_per_token
307
+ self.remove_cls_token = remove_cls_token
308
+
309
+ self.unet = TorchUNET1DSegmentationHead(
310
+ num_classes=embed_dimension // 2,
311
+ output_channels_list=tuple(
312
+ embed_dimension * (2**i) for i in range(num_layers)
313
+ ),
314
+ input_embed_dim=embed_dimension,
315
+ )
316
+
317
+ self.fc = nn.Linear(
318
+ embed_dimension,
319
+ self.nucl_per_token * self._num_classes * self._num_features,
320
+ )
321
+
322
+ def forward(
323
+ self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
324
+ ) -> Dict[str, torch.Tensor]:
325
+ if self.remove_cls_token:
326
+ x = x[:, 1:]
327
+
328
+ x = self.unet(x)
329
+ x = nn.functional.silu(x)
330
+
331
+ x = x.transpose(2, 1)
332
+ logits = self.fc(x)
333
+
334
+ batch_size, seq_len, _ = x.shape
335
+ logits = logits.view( # noqa
336
+ batch_size,
337
+ seq_len * self.nucl_per_token,
338
+ self._num_features,
339
+ self._num_classes,
340
+ )
341
+
342
+ return {"logits": logits}
343
+
344
 
345
  FEATURES = [
346
  "protein_coding_gene",