# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, final

from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    FeedForwardNetwork,
    MultiheadAttention,
    create_standard_layer_norm,
)
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import Dropout, Module

from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer


@final
class MonotonicTransformerDecoderLayer(Module):
    """Represents a Monotonic Transformer decoder layer."""

    self_attn: MultiheadAttention
    self_attn_dropout: Optional[Dropout]
    self_attn_layer_norm: LayerNorm
    encoder_decoder_attn: MultiheadAttention
    encoder_decoder_attn_dropout: Optional[Dropout]
    encoder_decoder_attn_layer_norm: LayerNorm
    p_choose_layer: PChooseLayer
    ffn: FeedForwardNetwork
    ffn_dropout: Optional[Dropout]
    ffn_layer_norm: LayerNorm

    def __init__(
        self,
        self_attn: MultiheadAttention,
        encoder_decoder_attn: MultiheadAttention,
        p_choose_layer: PChooseLayer,
        ffn: FeedForwardNetwork,
        *,
        dropout_p: float = 0.1,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param self_attn:
            The self attention layer.
        :param encoder_decoder_attn:
            The encoder-decoder attention layer.
        :param ffn:
            The feed-forward network.
        :param dropout_p:
            The dropout probability on outputs of the attention layers and the
            feed-forward network.
        """
        super().__init__()

        self.model_dim = self_attn.model_dim

        self_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.self_attn_layer_norm = self_attn_layer_norm

        self.self_attn = self_attn

        if dropout_p > 0.0:
            self.self_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("self_attn_dropout", None)

        encoder_decoder_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.encoder_decoder_attn_layer_norm = encoder_decoder_attn_layer_norm

        self.encoder_decoder_attn = encoder_decoder_attn

        if dropout_p > 0.0:
            self.encoder_decoder_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("encoder_decoder_attn_dropout", None)

        self.p_choose_layer = p_choose_layer

        ffn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.ffn_layer_norm = ffn_layer_norm

        self.ffn = ffn

        if dropout_p > 0.0:
            self.ffn_dropout = Dropout(dropout_p)
        else:
            self.register_module("ffn_dropout", None)

    @finaloverride
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
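        """
        :param seqs:
            The sequences to process.
        :param padding_mask:
            The padding mask of ``seqs``.
        :param self_attn_mask:
            The mask that will be applied to attention weights of the self
            attention layer.
        :param encoder_output:
            The encoder output to use in encoder-decoder attention.
        :param encoder_padding_mask:
            The padding mask of ``encoder_output``.
        :param state_bag:
            The state bag to use for incremental decoding.

        :returns:
            - The processed sequences.
            - The padding mask of the processed sequences.
            - The monotonic alignment (p_choose) probabilities computed by the
              p_choose layer.
        """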
        seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag)

        seqs, p_choose = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask
        )

        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask, p_choose

    def _forward_self_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[IncrementalStateBag],
    ) -> Tensor:
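        # Pre-LN residual block: normalize the input, apply self attention and
        # optional dropout, then add the residual connection.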
        residual = seqs

        seqs = self.self_attn_layer_norm(seqs)

        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs

    def _forward_encoder_decoder_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor],
        encoder_padding_mask: Optional[PaddingMask],
    ) -> Tuple[Tensor, Tensor]:
        if encoder_output is None:
            raise ValueError(
                "`encoder_output` must not be `None` for encoder-decoder attention."
            )

        residual = seqs

        seqs = self.encoder_decoder_attn_layer_norm(seqs)

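        # Compute the monotonic alignment probabilities from the normalized
        # decoder states and the encoder output.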
        p_choose = self.p_choose_layer(seqs, encoder_output)

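        # Cross-attend to the encoder output, using it as both the keys and
        # the values.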
        seqs = self.encoder_decoder_attn(
            seqs,
            padding_mask,
            encoder_output,
            encoder_padding_mask,
            encoder_output,
        )

        if self.encoder_decoder_attn_dropout is not None:
            seqs = self.encoder_decoder_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs, p_choose

    def _forward_ffn(self, seqs: Tensor) -> Tensor:
        residual = seqs

        seqs = self.ffn_layer_norm(seqs)

        seqs = self.ffn(seqs)

        if self.ffn_dropout is not None:
            seqs = self.ffn_dropout(seqs)

        seqs = seqs + residual

        return seqs