# victan's picture
# Upload seamless_communication/models/monotonic_decoder/model.py with huggingface_hub
# 5ec225a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, final
from fairseq2.models.transformer.frontend import TransformerFrontend
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.projection import Projection
from overrides import final as finaloverride
from torch import Tensor
from torch.nn import Module
from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
MonotonicTransformerDecoder,
)
@final
class MonotonicDecoderModel(Module):
    """Bundle a text decoder frontend, a monotonic decoder and a projection.

    The model exposes two steps: :meth:`decode` passes its input through the
    frontend and then through the decoder, and :meth:`project` applies the
    final projection to a decoder output.
    """

    # Applied to the inputs of ``decode`` before the decoder runs.
    text_decoder_frontend: TransformerFrontend
    # Receives the frontend output together with the encoder output.
    text_decoder: MonotonicTransformerDecoder
    # Applied by ``project`` to a decoder output.
    final_proj: Projection

    def __init__(
        self,
        text_decoder_frontend: TransformerFrontend,
        text_decoder: MonotonicTransformerDecoder,
        final_proj: Projection,
    ) -> None:
        """Store the three submodules on the model.

        :param text_decoder_frontend:
            The frontend called first in :meth:`decode`.
        :param text_decoder:
            The decoder called on the frontend output in :meth:`decode`.
        :param final_proj:
            The projection called in :meth:`project`.
        """
        super().__init__()

        # Assignment order is kept as-is: it fixes the order in which the
        # submodules are registered on the parent ``Module``.
        self.text_decoder_frontend = text_decoder_frontend
        self.text_decoder = text_decoder
        self.final_proj = final_proj

    @finaloverride
    def decode(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Tensor,
        encoder_padding_mask: Optional[PaddingMask],
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        """Run the frontend on ``seqs`` and feed the result to the decoder.

        :param seqs:
            The input handed to the text decoder frontend.
        :param padding_mask:
            The padding mask of ``seqs``, or ``None``.
        :param encoder_output:
            The encoder output forwarded to the decoder.
        :param encoder_padding_mask:
            The padding mask of ``encoder_output``, or ``None``.
        :param state_bag:
            Optional incremental state, passed to both the frontend and the
            decoder.

        :returns:
            Whatever the decoder returns, unchanged. The third element is
            presumably specific to the monotonic decoder -- TODO confirm
            against ``MonotonicTransformerDecoder``.
        """
        decoder_inputs, decoder_padding_mask = self.text_decoder_frontend(
            seqs, padding_mask, state_bag=state_bag
        )

        decoder_result = self.text_decoder(
            decoder_inputs,
            decoder_padding_mask,
            encoder_output,
            encoder_padding_mask,
            state_bag=state_bag,
        )

        return decoder_result  # type: ignore[no-any-return]

    @finaloverride
    def project(self, decoder_output: Tensor) -> Tensor:
        """Apply the final projection to ``decoder_output``.

        :param decoder_output:
            The tensor to project.

        :returns:
            The output of ``final_proj`` (the logits).
        """
        return self.final_proj(decoder_output)  # type: ignore[no-any-return]