import math
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_dna2vec import DNAEncoderConfig

class Encoder(nn.Module):
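    """Token embeddings plus fixed sinusoidal positional encodings,
    contextualized by a stack of pre-norm TransformerEncoder layers."""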
    def __init__(
        self,
        vocab_size: int = 4,
        embedding_dim: int = 384,
        dim_feedforward: int = 1536,
        num_heads: int = 12,
        num_layers: int = 6,
        dropout: float = 0.1,
        activation: Literal["relu", "gelu"] = "gelu",
        pos_embedding: Optional[str] = "SinusoidalPositionalEncoding",
        max_position_embeddings: int = 1024,
    ):
        """
        Default values taken from miniLM v6
        https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.emb_dropout = nn.Dropout(p=dropout)
        
        if pos_embedding == "SinusoidalPositionalEncoding":
            position = torch.arange(max_position_embeddings).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
            )
            pe = torch.zeros(max_position_embeddings, 1, embedding_dim)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            # reshape to (1, max_position_embeddings, embedding_dim) so it
            # broadcasts over the batch dimension of the batch-first embeddings
            pe = pe.squeeze(1).unsqueeze(0)
            self.register_buffer("positional_embedding", pe)
        else:
            raise ValueError(f"Positional embedding {pos_embedding} not found")

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        # create encode layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,  # following: https://arxiv.org/pdf/2002.04745.pdf
        )
        self.trf_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=num_layers
        )

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # input_ids.names = ["batch", "sequence"]
        # embedding does not support named tensors

        # Embed
        emb = self.emb_dropout(
            self.embedding(input_ids) + self.positional_embedding[:, :input_ids.size(1), :]
        )
        # emb.names = ["batch", "sequence", "embedding"]

        # Contextualize embeddings
        # nn.TransformerEncoder expects a boolean padding mask where True marks
        # positions to ignore, i.e. the inverse of the 1/0 attention_mask convention
        attn = None
        if attention_mask is not None:
            attn = attention_mask == 0
        out = self.trf_encoder(emb, src_key_padding_mask=attn)
        # out.names = ["batch", "sequence", "embedding"]
        return out

class DNAEncoder(PreTrainedModel):
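    """Hugging Face PreTrainedModel wrapper around the Encoder defined above,
    configured via DNAEncoderConfig."""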
    config_class = DNAEncoderConfig
    
    def __init__(self, config: DNAEncoderConfig):
        super().__init__(config)
        self.config = config
        self.encoder = Encoder(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            dim_feedforward=config.dim_feedforward,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            dropout=config.dropout,
            activation=config.activation,
            max_position_embeddings=config.max_position_embeddings,
        )
            
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.encoder(input_ids, attention_mask)
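

# Illustrative usage sketch: a minimal example of running the Encoder directly
# with its default hyperparameters. It assumes this module is imported as part
# of its package (the config import at the top is relative). Shapes follow the
# forward() contract above: input_ids is (batch, sequence) with token ids in
# [0, vocab_size); attention_mask uses 1 for real tokens and 0 for padding.
if __name__ == "__main__":
    encoder = Encoder()  # defaults mirror the all-MiniLM-L6-v2 dimensions
    input_ids = torch.randint(0, 4, (2, 16))  # batch of 2 sequences of length 16
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 8:] = 0  # mark the tail of the second sequence as padding
    out = encoder(input_ids, attention_mask)
    print(out.shape)  # torch.Size([2, 16, 384])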