Update modeling_omnigenome.py
Browse files
modeling_omnigenome.py  +468 −140
modeling_omnigenome.py  CHANGED
@@ -15,8 +15,11 @@
 """ PyTorch OmniGenome model."""
 
 import math
+import random
+import warnings
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -300,6 +303,178 @@ class OmniGenomeEmbeddings(nn.Module):
         )
         return position_ids.unsqueeze(0).expand(input_shape)
 
+#
+# # Copied from transformers.models.esm.modeling_esm.EsmSelfAttention with Esm->OmniGenome
+# class OmniGenomeSelfAttention(nn.Module):
+#     def __init__(self, config, position_embedding_type=None):
+#         super().__init__()
+#         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+#             config, "embedding_size"
+#         ):
+#             raise ValueError(
+#                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+#                 f"heads ({config.num_attention_heads})"
+#             )
+#
+#         self.num_attention_heads = config.num_attention_heads
+#         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+#         self.all_head_size = self.num_attention_heads * self.attention_head_size
+#
+#         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+#         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+#         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+#
+#         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+#         self.position_embedding_type = position_embedding_type or getattr(
+#             config, "position_embedding_type", "absolute"
+#         )
+#         self.rotary_embeddings = None
+#         if (
+#             self.position_embedding_type == "relative_key"
+#             or self.position_embedding_type == "relative_key_query"
+#         ):
+#             self.max_position_embeddings = config.max_position_embeddings
+#             self.distance_embedding = nn.Embedding(
+#                 2 * config.max_position_embeddings - 1, self.attention_head_size
+#             )
+#         elif self.position_embedding_type == "rotary":
+#             self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+#
+#         self.is_decoder = config.is_decoder
+#
+#     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+#         new_x_shape = x.size()[:-1] + (
+#             self.num_attention_heads,
+#             self.attention_head_size,
+#         )
+#         x = x.view(new_x_shape)
+#         return x.permute(0, 2, 1, 3)
+#
+#     def forward(
+#         self,
+#         hidden_states: torch.Tensor,
+#         attention_mask: Optional[torch.FloatTensor] = None,
+#         head_mask: Optional[torch.FloatTensor] = None,
+#         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+#         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+#         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+#         output_attentions: Optional[bool] = False,
+#     ) -> Tuple[torch.Tensor]:
+#         mixed_query_layer = self.query(hidden_states)
+#
+#         # If this is instantiated as a cross-attention module, the keys
+#         # and values come from an encoder; the attention mask needs to be
+#         # such that the encoder's padding tokens are not attended to.
+#         is_cross_attention = encoder_hidden_states is not None
+#
+#         if is_cross_attention and past_key_value is not None:
+#             # reuse k,v, cross_attentions
+#             key_layer = past_key_value[0]
+#             value_layer = past_key_value[1]
+#             attention_mask = encoder_attention_mask
+#         elif is_cross_attention:
+#             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+#             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+#             attention_mask = encoder_attention_mask
+#         elif past_key_value is not None:
+#             key_layer = self.transpose_for_scores(self.key(hidden_states))
+#             value_layer = self.transpose_for_scores(self.value(hidden_states))
+#             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+#             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+#         else:
+#             key_layer = self.transpose_for_scores(self.key(hidden_states))
+#             value_layer = self.transpose_for_scores(self.value(hidden_states))
+#
+#         query_layer = self.transpose_for_scores(mixed_query_layer)
+#
+#         # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+#         # OmniGenome scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+#         # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+#         # OmniGenome code and fix rotary embeddings.
+#         query_layer = query_layer * self.attention_head_size ** -0.5
+#
+#         if self.is_decoder:
+#             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+#             # Further calls to cross_attention layer can then reuse all cross-attention
+#             # key/value_states (first "if" case)
+#             # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+#             # all previous decoder key/value_states. Further calls to uni-directional self-attention
+#             # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+#             # if encoder bi-directional self-attention `past_key_value` is always `None`
+#             past_key_value = (key_layer, value_layer)
+#
+#         if self.position_embedding_type == "rotary":
+#             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+#
+#         # Take the dot product between "query" and "key" to get the raw attention scores.
+#         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+#
+#         if (
+#             self.position_embedding_type == "relative_key"
+#             or self.position_embedding_type == "relative_key_query"
+#         ):
+#             seq_length = hidden_states.size()[1]
+#             position_ids_l = torch.arange(
+#                 seq_length, dtype=torch.long, device=hidden_states.device
+#             ).view(-1, 1)
+#             position_ids_r = torch.arange(
+#                 seq_length, dtype=torch.long, device=hidden_states.device
+#             ).view(1, -1)
+#             distance = position_ids_l - position_ids_r
+#             positional_embedding = self.distance_embedding(
+#                 distance + self.max_position_embeddings - 1
+#             )
+#             positional_embedding = positional_embedding.to(
+#                 dtype=query_layer.dtype
+#             )  # fp16 compatibility
+#
+#             if self.position_embedding_type == "relative_key":
+#                 relative_position_scores = torch.einsum(
+#                     "bhld,lrd->bhlr", query_layer, positional_embedding
+#                 )
+#                 attention_scores = attention_scores + relative_position_scores
+#             elif self.position_embedding_type == "relative_key_query":
+#                 relative_position_scores_query = torch.einsum(
+#                     "bhld,lrd->bhlr", query_layer, positional_embedding
+#                 )
+#                 relative_position_scores_key = torch.einsum(
+#                     "bhrd,lrd->bhlr", key_layer, positional_embedding
+#                 )
+#                 attention_scores = (
+#                     attention_scores
+#                     + relative_position_scores_query
+#                     + relative_position_scores_key
+#                 )
+#
+#         if attention_mask is not None:
+#             # Apply the attention mask is (precomputed for all layers in OmniGenomeModel forward() function)
+#             attention_scores = attention_scores + attention_mask
+#
+#         # Normalize the attention scores to probabilities.
+#         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+#
+#         # This is actually dropping out entire tokens to attend to, which might
+#         # seem a bit unusual, but is taken from the original Transformer paper.
+#         attention_probs = self.dropout(attention_probs)
+#
+#         # Mask heads if we want to
+#         if head_mask is not None:
+#             attention_probs = attention_probs * head_mask
+#
+#         context_layer = torch.matmul(attention_probs, value_layer)
+#
+#         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+#         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+#         context_layer = context_layer.view(new_context_layer_shape)
+#
+#         outputs = (
+#             (context_layer, attention_probs) if output_attentions else (context_layer,)
+#         )
+#
+#         if self.is_decoder:
+#             outputs = outputs + (past_key_value,)
+#         return outputs
+
 
 # Copied from transformers.models.esm.modeling_esm.EsmSelfAttention with Esm->OmniGenome
 class OmniGenomeSelfAttention(nn.Module):
@@ -339,6 +514,14 @@ class OmniGenomeSelfAttention(nn.Module):
 
         self.is_decoder = config.is_decoder
 
+        # FlashAttention parameters
+        self.enable_flash_attn = getattr(config, "use_flash_attention", True)
+        if self.enable_flash_attn:
+            from flash_attn import flash_attn_func
+            self.flash_attn_func = flash_attn_func
+        else:
+            self.flash_attn_func = None
+
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (
             self.num_attention_heads,
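Note on the constructor change above: `from flash_attn import flash_attn_func` runs whenever `use_flash_attention` is not explicitly disabled in the config, so instantiating the model without the flash-attn package raises an ImportError. A minimal sketch of a fallback-guarded variant (my illustration, not part of this commit; the helper name `resolve_flash_attn` is hypothetical):

```python
def resolve_flash_attn(config):
    # Return flash_attn_func only when requested and installed; otherwise signal the eager fallback.
    if not getattr(config, "use_flash_attention", True):
        return None
    try:
        from flash_attn import flash_attn_func
        return flash_attn_func
    except ImportError:
        return None
```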
@@ -359,13 +542,9 @@ class OmniGenomeSelfAttention(nn.Module):
     ) -> Tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)
 
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
         is_cross_attention = encoder_hidden_states is not None
 
         if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
             key_layer = past_key_value[0]
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
@@ -384,95 +563,75 @@ class OmniGenomeSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
-        # OmniGenome scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
-        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
-        # OmniGenome code and fix rotary embeddings.
-        query_layer = query_layer * self.attention_head_size ** -0.5
-
         if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_layer, value_layer)
 
-        if self.position_embedding_type == "rotary":
+        # Decide whether the FlashAttention path can be used
+        use_flash_attn = self.enable_flash_attn and self.position_embedding_type == "rotary"
+        if use_flash_attn:
+            # Apply rotary position embeddings
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-
-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(
-                seq_length, dtype=torch.long, device=hidden_states.device
-            ).view(-1, 1)
-            position_ids_r = torch.arange(
-                seq_length, dtype=torch.long, device=hidden_states.device
-            ).view(1, -1)
-            distance = position_ids_l - position_ids_r
-            positional_embedding = self.distance_embedding(
-                distance + self.max_position_embeddings - 1
+            # Reorder dimensions to [batch_size, seq_len, num_heads, head_dim]
+            q = query_layer.transpose(1, 2).half()
+            k = key_layer.transpose(1, 2).half()
+            v = value_layer.transpose(1, 2).half()
+
+            # Compute attention with FlashAttention
+            context_layer = self.flash_attn_func(
+                q, k, v,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                softmax_scale=self.attention_head_size ** -0.5,
+                causal=self.is_decoder
             )
-            positional_embedding = positional_embedding.to(
-                dtype=query_layer.dtype
-            )  # fp16 compatibility
 
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding
-                )
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding
-                )
-                relative_position_scores_key = torch.einsum(
-                    "bhrd,lrd->bhlr", key_layer, positional_embedding
-                )
-                attention_scores = (
-                    attention_scores
-                    + relative_position_scores_query
-                    + relative_position_scores_key
-                )
+            # Restore the dimension order [batch_size, num_heads, seq_len, head_dim]
+            context_layer = context_layer.transpose(1, 2).to(hidden_states.dtype)
+        else:
+            # Original implementation
+            query_layer = query_layer * self.attention_head_size ** -0.5
 
-            if attention_mask is not None:
-                # Apply the attention mask is (precomputed for all layers in OmniGenomeModel forward() function)
-                attention_scores = attention_scores + attention_mask
+            if self.position_embedding_type == "rotary":
+                query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+            if self.position_embedding_type in ["relative_key", "relative_key_query"]:
+                seq_length = hidden_states.size()[1]
+                position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+                position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+                distance = position_ids_l - position_ids_r
+                positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+                positional_embedding = positional_embedding.to(dtype=query_layer.dtype)
 
-            # Normalize the attention scores to probabilities.
-            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+                if self.position_embedding_type == "relative_key":
+                    relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                    attention_scores = attention_scores + relative_position_scores
+                elif self.position_embedding_type == "relative_key_query":
+                    relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                    relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                    attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
 
-            # This is actually dropping out entire tokens to attend to, which might
-            # seem a bit unusual, but is taken from the original Transformer paper.
-            attention_probs = self.dropout(attention_probs)
+            if attention_mask is not None:
+                attention_scores = attention_scores + attention_mask
 
-            # Mask heads if we want to
-            if head_mask is not None:
-            attention_probs = attention_probs * head_mask
+            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+            attention_probs = self.dropout(attention_probs)
 
-            context_layer = torch.matmul(attention_probs, value_layer)
+            if head_mask is not None:
+                attention_probs = attention_probs * head_mask
+
+            context_layer = torch.matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
 
-        outputs = (
-            (context_layer, attention_probs) if output_attentions else (context_layer,)
-        )
-
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         if self.is_decoder:
             outputs = outputs + (past_key_value,)
         return outputs
 
-
 # Copied from transformers.models.esm.modeling_esm.EsmSelfOutput with Esm->OmniGenome
 class OmniGenomeSelfOutput(nn.Module):
     def __init__(self, config):
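For readers unfamiliar with flash_attn_func: the eager path above works on `[batch, num_heads, seq_len, head_dim]` tensors, while the FlashAttention branch transposes to `[batch, seq_len, num_heads, head_dim]` and casts to half precision before the call, then reverses both steps. A self-contained sketch of that round trip (illustration only; the flash_attn_func call is replaced here by an identity stand-in):

```python
import torch

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)  # eager layout
q_flash = q.transpose(1, 2).half()                # layout and dtype expected by flash_attn_func
out_flash = q_flash                               # stand-in for the flash_attn_func output
out = out_flash.transpose(1, 2).to(q.dtype)       # back to the eager layout and dtype
assert out.shape == q.shape
```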
@@ -530,6 +689,7 @@ class OmniGenomeAttention(nn.Module):
         output_attentions=False,
     ):
         hidden_states_ln = self.LayerNorm(hidden_states)
+        hidden_states_ln = hidden_states_ln.to(hidden_states.dtype)
         self_outputs = self.self(
             hidden_states_ln,
             attention_mask,
@@ -1053,6 +1213,7 @@ class OmniGenomeModel(OmniGenomePreTrainedModel):
             inputs_embeds=inputs_embeds,
             past_key_values_length=past_key_values_length,
         )
+        embedding_output = embedding_output.half()
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
@@ -1117,7 +1278,7 @@ class OmniGenomeForMaskedLM(OmniGenomePreTrainedModel):
 
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.lm_head = OmniGenomeLMHead(config)
-
+        self.init_weights()
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1237,7 +1398,7 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
         self.config = config
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
         self.classifier = OmniGenomeClassificationHead(config)
-
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1279,8 +1440,8 @@ class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-
-        logits = self.classifier(
+        last_hidden_state = outputs[0]
+        logits = self.classifier(last_hidden_state)
 
         loss = None
         if labels is not None:
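A hedged usage sketch for the sequence-classification head touched above (the checkpoint path is a placeholder; it assumes the custom OmniGenome classes are loadable with `trust_remote_code=True`):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/omnigenome-checkpoint")
model = AutoModelForSequenceClassification.from_pretrained(
    "path/to/omnigenome-checkpoint", trust_remote_code=True
)

inputs = tokenizer("AUGGCUACGUAGCUAG", return_tensors="pt")
logits = model(**inputs).logits  # shape: (batch_size, num_labels)
```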
@@ -1336,12 +1497,10 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
-        self.lm_head = OmniGenomeLMHead(config)
         self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
         self.classifier = torch.nn.Linear(self.config.hidden_size, self.num_labels)
-        self.
-        self.
-        # self.init_weights()
+        self.softmax = nn.Softmax(dim=-1)
+        self.init_weights()
 
     @add_start_docstrings_to_model_forward(
         OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1367,12 +1526,12 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-
+
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-
-
+
+        outputs = self.OmniGenome(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1382,17 +1541,11 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        try:
-            last_hidden_state = mlm_outputs[0]
-            last_hidden_state = self.dense(last_hidden_state)
-        except:
-            last_hidden_state = mlm_outputs.hidden_states[-1]
-            last_hidden_state = self.dense(last_hidden_state)
 
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.dense(last_hidden_state)
         logits = self.classifier(last_hidden_state)
-        logits =
-        logits = self.activation(logits)
-        logits = self.dropout(logits)
+        logits = self.softmax(logits)
 
         loss = None
         if labels is not None:
@@ -1400,14 +1553,14 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
-            output = (logits,) +
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=
-            attentions=
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     @staticmethod
@@ -1433,15 +1586,26 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
 
         return structure
 
-    def
+    def predict_rna_structure(
         self,
-
-        attention_mask: Optional[torch.Tensor] = None,
+        sequence: str,
         **kwargs
     ) -> List[str]:
+        r"""
+        Load the pretrained OmniGenome Model to do zero-shot prediction of the secondary structure
+        of a sequence given the sequence
         """
-
-
+        if self.tokenizer is None:
+            tokenizer = kwargs.get("tokenizer", None)
+            if tokenizer is None:
+                from transformers import AutoTokenizer
+                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
+            else:
+                self.tokenizer = tokenizer
+
+        inputs = self.tokenizer(sequence, return_tensors="pt", padding="max_length", truncation=True)
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs["attention_mask"]
         outputs = self.forward(input_ids, attention_mask, **kwargs)
 
         logits = torch.argmax(outputs.logits, dim=-1)
@@ -1458,18 +1622,26 @@ class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
 
 @add_start_docstrings(
     """
-
+    This is not a standard Seq2Seq model. Instead, this model is designed for RNA design tasks.
+    This is the OmniGenome Model with a simple genetic algorithm based RNA design head on top.
     """,
     OmniGenome_START_DOCSTRING,
 )
-class
+class OmniGenomeModelForSeq2SeqLM(OmniGenomePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
-        self.OmniGenome =
+        self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.lm_head = OmniGenomeLMHead(config)
        self.num_generation = config.num_generation
        self.num_population = config.num_population
-
+        self.init_weights()
+
+        self.tokenizer = None
+        self.predict_structure = None
+
+        warnings.warn(f"This model {self.__class__.__name__} is not a real Seq2Seq model. "
+                      f"Instead, this model is designed for RNA design tasks")
 
    @add_start_docstrings_to_model_forward(
        OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
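A hedged usage sketch for the `predict_rna_structure` helper added above (placeholder checkpoint path; setting `model.tokenizer` directly avoids forwarding a `tokenizer` kwarg into `forward`):

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained(
    "path/to/omnigenome-checkpoint", trust_remote_code=True
)
model.tokenizer = AutoTokenizer.from_pretrained("path/to/omnigenome-checkpoint")

with torch.no_grad():
    structure = model.predict_rna_structure("GGGAAACUUCGGUUUCCC")
print(structure)  # e.g. a dot-bracket annotation per token
```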
@@ -1495,43 +1667,199 @@ class OmniGenomeMaskedLMForRNADesign(OmniGenomePreTrainedModel):
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-
-
-        )
-
-        outputs = self.OmniGenome(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = outputs[0]
-
-        sequence_output = self.dropout(sequence_output)
-        logits = self.classifier(sequence_output)
+        raise NotImplementedError("This model is not designed for standard Seq2Seq tasks. "
+                                  "Use model.rna_sequence_design() for RNA sequences design instead.")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def rna_sequence_design(
+        self,
+        structure: str,
+        predict_structure_func=None,
+        **kwargs
+    ) -> List[str]:
+        """
+        Assemble the RNA sequence given the reference sequence structure
+        """
+        if self.tokenizer is None:
+            tokenizer = kwargs.get("tokenizer", None)
+            if tokenizer is None:
+                from transformers import AutoTokenizer
+                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
+            else:
+                self.tokenizer = tokenizer
+
+        candidates = self.genetic_algorithm_for_rna_design(structure, predict_structure_func=None, **kwargs)
+
+        return candidates
+
+    def genetic_algorithm_for_rna_design(self, structure, predict_structure_func=None, **kwargs):
+        if predict_structure_func is None:
+            import ViennaRNA
+
+            def predict_structure(sequence):
+                return ViennaRNA.fold(sequence)[0]
+
+            predict_structure_func = predict_structure
+
+        self.predict_structure = predict_structure_func
+        mutation_ratio = kwargs.get("mutation_ratio", 0.5)
+        num_population = kwargs.get("num_population", self.num_population)
+        num_generation = kwargs.get("num_generation", self.num_generation)
+        import tqdm
+        population = self.init_population(structure, num_population)
+        population = self.mlm_mutate(population, structure, mutation_ratio=mutation_ratio)
+        for generation_id in tqdm.tqdm(range(num_generation), desc="Designing RNA Sequence"):
+            population_fitness = self.sequence_fitness(population, structure)[:num_population]
+            population = sorted(zip(population, population_fitness), key=lambda x: x[1])[:num_population]
+            population = [x[0] for x in population]
+            next_generation = population  # Elitism
+            next_generation += self.crossover(population, structure)
+            next_generation += self.mlm_mutate(next_generation, structure, mutation_ratio)
+            fitness_values = self.sequence_fitness(next_generation, structure)
+            next_generation = sorted(zip(next_generation, fitness_values), key=lambda x: x[1])
+
+            candidate_sequences = []
+            for sequence, fitness in next_generation:
+                if fitness == 0:
+                    candidate_sequences.append(sequence)
+                else:
+                    break
+            if candidate_sequences:
+                return candidate_sequences
+            print(f"Generation {generation_id}: {next_generation[0][0]} with fitness {next_generation[0][1]}")
+            population = [x[0] for x in next_generation[:num_population]]
+
+        return []
+
+    def init_population(self, structure, num_population):
+        # Initialize lists to store population data and inputs for masked language model
+        population = []
+        mlm_inputs = []
+        # Iterate over the number of individuals in the population
+        for _ in range(num_population):  # Changed from self.num_population to num_population
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            masked_sequence = [
+                random.choice(["A", "G", "C", "T", "<mask>"])
+                for _ in range(len(structure))
+            ]
+            masked_sequence_str = "".join(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence_str}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(mlm_inputs[i].replace('<mask>', '$')))
+            ]
+            population.append("".join(fixed_sequence))
+
+        return population
+
+    def mlm_mutate(self, population, structure, mutation_ratio):
+        def mutate(sequence, mutation_rate):
+            sequence = np.array(list(sequence), dtype=np.str_)
+            probability_matrix = np.full(sequence.shape, mutation_rate)
+            masked_indices = np.random.rand(*sequence.shape) < probability_matrix
+            sequence[masked_indices] = "$"
+            mut_seq = "".join(sequence.tolist()).replace("$", "<mask>")
+            return mut_seq
+
+        # Initialize lists to store population data and inputs for masked language model
+        mlm_inputs = []
+        masked_sequences = []
+
+        # Iterate over the number of individuals in the population
+        for sequence in population:
+            # Create a sequence by randomly choosing nucleotides or a mask token for each position in the structure
+            masked_sequence = mutate(sequence, mutation_ratio)
+            masked_sequences.append(masked_sequence)
+            mlm_inputs.append(f"{masked_sequence}<eos>{''.join(structure)}")
+
+        # Call a function to predict outputs using the masked language model
+        outputs = self.mlm_predict(mlm_inputs, structure)
+
+        mut_population = []
+
+        # Decode the mlm outputs and construct the initial population
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(masked_sequences[i].replace('<mask>', '$')))
+            ]
+            mut_population.append("".join(fixed_sequence))
+
+        return mut_population
+
+    def crossover(self, population, structure):
+        crossover_population = []
+        batch_crossover_inputs = []
+        for i in range(len(population)):
+            parent1, parent2 = random.choices(population, k=2)
+            pos = random.randint(1, len(parent1) - 1)
+            child1 = parent1[:pos] + "<mask>" * len(parent2[pos:])
+            child2 = "<mask>" * len(parent1[:pos]) + parent2[pos:]
+            batch_crossover_inputs.append(f"{child1}<eos>{structure}")
+            batch_crossover_inputs.append(f"{child2}<eos>{structure}")
+
+        outputs = self.mlm_predict(batch_crossover_inputs, structure)
+
+        for i in range(len(outputs)):
+            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
+            fixed_sequence = [
+                x if x in "AGCT" else random.choice(["G", "C"])
+                for x, y in zip(sequence, list(batch_crossover_inputs[i].replace('<mask>', '$')))
+            ]
+            crossover_population.append("".join(fixed_sequence))
+
+        return crossover_population
+
+    def sequence_fitness(self, sequences, structure):
+        fitness_values = []
+        structures = [self.predict_structure(sequence) for sequence in sequences]
+        for predicted_structure in structures:
+            scores = []
+            for i in range(len(predicted_structure)):
+                if predicted_structure[i] == structure[i]:
+                    scores.append(1)
+                elif (
+                    predicted_structure[i] == ")"
+                    and structure[i] == "("
+                    or predicted_structure[i] == "("
+                    and structure[i] == ")"
+                ):
+                    scores.append(-3)
+                else:
+                    scores.append(0)
+            score = 1 - sum(scores) / len(structure)
+            fitness_values.append(score)
+        return fitness_values
+
+    def mlm_predict(self, mlm_inputs, structure):
+        batch_size = 4
+        all_outputs = []
+        from transformers import set_seed
+        set_seed(random.randint(0, 99999999), deterministic=False)
+
+        with torch.no_grad():
+            for i in range(0, len(mlm_inputs), batch_size):
+                batch_mlm_inputs = self.tokenizer(
+                    mlm_inputs[i:i + batch_size],
+                    padding=True,
+                    max_length=len(mlm_inputs[0]) // 2,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                batch_mlm_inputs = batch_mlm_inputs.to(self.device)
+                outputs = self.OmniGenome(**batch_mlm_inputs)[0]
+                outputs = self.lm_head(outputs)
+                outputs = outputs.argmax(dim=-1)
+                all_outputs.append(outputs)
+            outputs = torch.cat(all_outputs, dim=0)
+        return outputs[:, 1:1 + len(structure)]
 
 
 # Copied from transformers.models.esm.modeling_esm.EsmClassificationHead with Esm->OmniGenome
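A hedged end-to-end sketch of the RNA design entry point added above (placeholder checkpoint path; the default fitness function folds candidates with ViennaRNA, so the ViennaRNA Python bindings must be installed, and `num_generation`/`num_population` come from the model config unless overridden via kwargs):

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("path/to/omnigenome-rna-design", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/omnigenome-rna-design")

target_structure = "(((((......)))))"
candidates = model.rna_sequence_design(
    target_structure, tokenizer=tokenizer, num_generation=20, num_population=50
)
print(candidates[:5] if candidates else "no exact structural match found")
```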