# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, final

from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    FeedForwardNetwork,
    MultiheadAttention,
    create_standard_layer_norm,
)
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import Dropout, Module

from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer

@final
class MonotonicTransformerDecoderLayer(Module):
    """Represents a Monotonic Transformer decoder layer."""

    self_attn: MultiheadAttention
    self_attn_dropout: Optional[Dropout]
    self_attn_layer_norm: LayerNorm
    encoder_decoder_attn: MultiheadAttention
    encoder_decoder_attn_dropout: Optional[Dropout]
    encoder_decoder_attn_layer_norm: LayerNorm
    p_choose_layer: PChooseLayer
    ffn: FeedForwardNetwork
    ffn_dropout: Optional[Dropout]
    ffn_layer_norm: LayerNorm

    def __init__(
        self,
        self_attn: MultiheadAttention,
        encoder_decoder_attn: MultiheadAttention,
        p_choose_layer: PChooseLayer,
        ffn: FeedForwardNetwork,
        *,
        dropout_p: float = 0.1,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param self_attn:
            The self-attention layer.
        :param encoder_decoder_attn:
            The encoder-decoder attention layer.
        :param p_choose_layer:
            The layer that computes the monotonic selection probabilities
            (``p_choose``) between target and source positions.
        :param ffn:
            The feed-forward network.
        :param dropout_p:
            The dropout probability on outputs of the attention layers and the
            feed-forward network.
        """
        super().__init__()

        self.model_dim = self_attn.model_dim

        # Self-attention block (pre-LayerNorm: normalization is applied before
        # the attention in `_forward_self_attn`).
        self_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.self_attn_layer_norm = self_attn_layer_norm
        self.self_attn = self_attn

        if dropout_p > 0.0:
            self.self_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("self_attn_dropout", None)

        # Encoder-decoder (cross) attention block.
        encoder_decoder_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.encoder_decoder_attn_layer_norm = encoder_decoder_attn_layer_norm
        self.encoder_decoder_attn = encoder_decoder_attn

        if dropout_p > 0.0:
            self.encoder_decoder_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("encoder_decoder_attn_dropout", None)

        self.p_choose_layer = p_choose_layer

        # Feed-forward network block.
        ffn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.ffn_layer_norm = ffn_layer_norm
        self.ffn = ffn

        if dropout_p > 0.0:
            self.ffn_dropout = Dropout(dropout_p)
        else:
            self.register_module("ffn_dropout", None)
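
    # Usage sketch (construction): in practice the sub-modules are assembled by
    # the monotonic decoder builder in this package. The names below
    # (`my_self_attn`, `my_cross_attn`, `my_p_choose`, `my_ffn`) are
    # placeholders for already-constructed fairseq2/seamless_communication
    # modules and are only illustrative.
    #
    #     layer = MonotonicTransformerDecoderLayer(
    #         self_attn=my_self_attn,              # MultiheadAttention
    #         encoder_decoder_attn=my_cross_attn,  # MultiheadAttention
    #         p_choose_layer=my_p_choose,          # PChooseLayer
    #         ffn=my_ffn,                          # FeedForwardNetwork
    #         dropout_p=0.1,
    #     )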

    @finaloverride
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag)

        seqs, p_choose = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask
        )

        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask, p_choose
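
    # Usage sketch (incremental decoding): for simultaneous generation, `forward`
    # is typically called one target step at a time with an `IncrementalStateBag`
    # so that self-attention can reuse keys/values cached from earlier steps; the
    # exact construction and stepping API of the state bag is fairseq2's (see
    # `fairseq2.nn.incremental_state`). `target_steps` and `encoder_output` below
    # are illustrative placeholders.
    #
    #     state_bag = IncrementalStateBag(max_num_steps)  # constructor arg assumed
    #     for step_seqs in target_steps:                  # (batch, 1, model_dim)
    #         step_seqs, _, p_choose = layer(
    #             step_seqs,
    #             None,
    #             encoder_output=encoder_output,
    #             state_bag=state_bag,
    #         )
    #         # Advance the state bag's step counter before the next call.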

    def _forward_self_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[IncrementalStateBag],
    ) -> Tensor:
        residual = seqs

        seqs = self.self_attn_layer_norm(seqs)

        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs

    def _forward_encoder_decoder_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor],
        encoder_padding_mask: Optional[PaddingMask],
    ) -> Tuple[Tensor, Tensor]:
        if encoder_output is None:
            raise ValueError(
                "`encoder_output` must not be `None` for encoder-decoder attention."
            )

        residual = seqs

        seqs = self.encoder_decoder_attn_layer_norm(seqs)

        # Compute the monotonic selection probabilities from the normalized
        # queries and the encoder output before applying cross-attention.
        p_choose = self.p_choose_layer(seqs, encoder_output)

        seqs = self.encoder_decoder_attn(
            seqs,
            padding_mask,
            keys=encoder_output,
            key_padding_mask=encoder_padding_mask,
            values=encoder_output,
        )

        if self.encoder_decoder_attn_dropout is not None:
            seqs = self.encoder_decoder_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs, p_choose

    def _forward_ffn(self, seqs: Tensor) -> Tensor:
        residual = seqs

        seqs = self.ffn_layer_norm(seqs)

        seqs = self.ffn(seqs)

        if self.ffn_dropout is not None:
            seqs = self.ffn_dropout(seqs)

        seqs = seqs + residual

        return seqs
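
# Usage sketch (full-sequence forward): illustrative shapes only; assumes `layer`
# is a constructed MonotonicTransformerDecoderLayer whose model_dim matches the
# tensors below. In training, a causal `AttentionMask` would normally be passed
# as `self_attn_mask`.
#
#     import torch
#
#     seqs = torch.randn(2, 8, 1024)             # (batch, tgt_len, model_dim)
#     encoder_output = torch.randn(2, 32, 1024)  # (batch, src_len, model_dim)
#
#     seqs, padding_mask, p_choose = layer(
#         seqs, None, encoder_output=encoder_output
#     )
#     # `p_choose` holds the monotonic selection probabilities computed by this
#     # layer's PChooseLayer.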