victan committed on
Commit
5ec225a
·
1 Parent(s): 719e3d4

Upload seamless_communication/models/monotonic_decoder/model.py with huggingface_hub

Browse files
seamless_communication/models/monotonic_decoder/model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, final
8
+
9
+ from fairseq2.models.transformer.frontend import TransformerFrontend
10
+ from fairseq2.nn.incremental_state import IncrementalStateBag
11
+ from fairseq2.nn.padding import PaddingMask
12
+ from fairseq2.nn.projection import Projection
13
+ from overrides import final as finaloverride
14
+ from torch import Tensor
15
+ from torch.nn import Module
16
+
17
+ from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
18
+ MonotonicTransformerDecoder,
19
+ )
20
+
21
+
22
@final
class MonotonicDecoderModel(Module):
    """Text decoder with a monotonic attention mechanism.

    Pipeline: token embedding via the frontend, monotonic transformer
    decoding against encoder output, then a final projection to logits.
    """

    # Embeds and positionally encodes target token sequences.
    text_decoder_frontend: TransformerFrontend
    # Decoder stack producing hidden states plus monotonic attention probs.
    text_decoder: MonotonicTransformerDecoder
    # Maps decoder hidden states to vocabulary logits.
    final_proj: Projection

    def __init__(
        self,
        text_decoder_frontend: TransformerFrontend,
        text_decoder: MonotonicTransformerDecoder,
        final_proj: Projection,
    ) -> None:
        """
        :param text_decoder_frontend: Frontend applied to target sequences.
        :param text_decoder: The monotonic transformer decoder.
        :param final_proj: Projection from decoder output to logits.
        """
        super().__init__()

        self.final_proj = final_proj
        self.text_decoder = text_decoder
        self.text_decoder_frontend = text_decoder_frontend

    @finaloverride
    def decode(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Tensor,
        encoder_padding_mask: Optional[PaddingMask],
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        """Run the frontend and the monotonic decoder over ``seqs``.

        :param seqs: Target token sequences.
        :param padding_mask: Padding mask of ``seqs``.
        :param encoder_output: Encoder output to attend over.
        :param encoder_padding_mask: Padding mask of ``encoder_output``.
        :param state_bag: Incremental decoding state (inference only).
        :returns: Decoder output, its padding mask, and the monotonic
            attention probabilities — as produced by ``text_decoder``.
        """
        embeds, embeds_padding_mask = self.text_decoder_frontend(
            seqs, padding_mask, state_bag=state_bag
        )

        decoder_result = self.text_decoder(
            embeds,
            embeds_padding_mask,
            encoder_output,
            encoder_padding_mask,
            state_bag=state_bag,
        )

        return decoder_result  # type: ignore[no-any-return]

    @finaloverride
    def project(self, decoder_output: Tensor) -> Tensor:
        """Project decoder output to vocabulary logits.

        :param decoder_output: Output of :meth:`decode`.
        :returns: Logits over the target vocabulary.
        """
        return self.final_proj(decoder_output)  # type: ignore[no-any-return]