from typing import Optional, Tuple, final

from fairseq2.models.transformer.frontend import TransformerFrontend
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.projection import Projection
from overrides import final as finaloverride
from torch import Tensor
from torch.nn import Module

from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
    MonotonicTransformerDecoder,
)


@final
class MonotonicDecoderModel(Module):
    """Represents a Transformer-based monotonic text decoder model."""

    text_decoder_frontend: TransformerFrontend
    text_decoder: MonotonicTransformerDecoder
    final_proj: Projection

    def __init__(
        self,
        text_decoder_frontend: TransformerFrontend,
        text_decoder: MonotonicTransformerDecoder,
        final_proj: Projection,
    ) -> None:
        super().__init__()

        self.text_decoder_frontend = text_decoder_frontend
        self.text_decoder = text_decoder
        self.final_proj = final_proj

    @finaloverride
    def decode(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Tensor,
        encoder_padding_mask: Optional[PaddingMask],
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        # Embed and position-encode the target tokens.
        seqs, padding_mask = self.text_decoder_frontend(
            seqs, padding_mask, state_bag=state_bag
        )

        # The monotonic decoder returns its output, the (possibly updated)
        # padding mask, and the monotonic alignment probabilities.
        return self.text_decoder(
            seqs,
            padding_mask,
            encoder_output,
            encoder_padding_mask,
            state_bag=state_bag,
        )

    @finaloverride
    def project(self, decoder_output: Tensor) -> Tensor:
        # Convert decoder states into vocabulary-sized logits.
        logits = self.final_proj(decoder_output)

        return logits
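

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module): one possible
# way a caller might run a single incremental decoding step. The helper name
# is hypothetical, ``model`` is assumed to have been built elsewhere (e.g. by
# a model loader), ``encoder_output`` is assumed to come from the upstream
# encoder, and the third element of ``decode``'s return value is assumed to
# hold the monotonic alignment probabilities.
# ---------------------------------------------------------------------------
def _example_decode_step(
    model: MonotonicDecoderModel,
    target_prefix: Tensor,
    encoder_output: Tensor,
    state_bag: IncrementalStateBag,
) -> Tuple[Tensor, Tensor]:
    """Decode one step and return next-token logits and alignment probs."""
    # A single unpadded prefix needs no padding masks, so pass ``None``.
    decoder_output, _, p_choose = model.decode(
        target_prefix, None, encoder_output, None, state_bag=state_bag
    )

    # Project only the last position to vocabulary-sized logits.
    logits = model.project(decoder_output[:, -1:])

    return logits, p_choose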