victan committed on
Commit
8266600
·
1 Parent(s): 411bb63

Upload seamless_communication/models/aligner/builder.py with huggingface_hub

Browse files
seamless_communication/models/aligner/builder.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Union
9
+
10
+ import torch
11
+ from fairseq2.assets.card import AssetCard
12
+ from fairseq2.data.vocabulary_info import VocabularyInfo
13
+ from fairseq2.models.utils.arch_registry import ArchitectureRegistry
14
+ from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
15
+ from fairseq2.typing import DataType, Device
16
+
17
+ from seamless_communication.models.aligner.model import (
18
+ UnitY2AlignmentEncoder,
19
+ UnitY2AlignmentFrontend,
20
+ UnitY2AlignmentModel,
21
+ )
22
+ from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
23
+ from seamless_communication.models.unity.loader import load_unity_unit_tokenizer
24
+
25
+
26
@dataclass
class AlignmentEncoderConfig:
    """Hyper-parameters consumed when building the UnitY2 alignment encoder."""

    # Passed to the encoder as ``embed_dim``; also used as the embedding
    # dimension of the text and unit embedding tables.
    model_dim: int
    # Passed to the encoder as ``feat_dim``.
    feat_dim: int
    # Passed to the encoder as ``text_layers``.
    num_text_layers: int
    # Passed to the encoder as ``feat_layers``.
    num_feat_layers: int
    # Dropout setting forwarded unchanged to the encoder.
    dropout: float
    # Temperature setting forwarded unchanged to the encoder.
    temperature: float
    # Reduction factor forwarded unchanged to the encoder.
    reduction_factor: int
41
+
42
+
43
@dataclass
class UnitY2AlignmentFrontendConfig:
    """Vocabulary settings used to size the alignment frontend embeddings."""

    # Unit vocabulary description; its ``size`` sets the number of unit
    # embeddings and its ``pad_idx`` is used as the padding index.
    unit_vocab_info: VocabularyInfo
    # Number of entries in the text embedding table.
    text_vocab_size: int
48
+
49
+
50
@dataclass
class UnitY2AlignmentConfig:
    """Top-level configuration of a UnitY2 alignment model."""

    # Name or asset card handed to the text and unit tokenizer loaders.
    model_name_or_card: Union[str, AssetCard]
    # Settings for the alignment encoder.
    alignment_encoder_config: AlignmentEncoderConfig
    # Settings for the alignment frontend (embedding tables).
    alignment_frontend_config: UnitY2AlignmentFrontendConfig
57
+
58
+
59
# Registry of UnitY2 aligner configurations, created under the
# "unity2_aligner" name.
aligner_archs = ArchitectureRegistry[UnitY2AlignmentConfig]("unity2_aligner")

# Shorthand for the registry's decorator; applied to configuration factory
# functions together with the architecture name to register them under.
aligner_arch = aligner_archs.decorator
62
+
63
+
64
@aligner_arch("nar_t2u_aligner")
def _aligner_nar_t2u() -> UnitY2AlignmentConfig:
    """Return the configuration registered as the ``nar_t2u_aligner`` architecture."""
    unit_vocab = VocabularyInfo(
        size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
    )

    return UnitY2AlignmentConfig(
        model_name_or_card="nar_t2u_aligner",
        alignment_encoder_config=AlignmentEncoderConfig(
            model_dim=1024,
            feat_dim=1024,
            num_text_layers=2,
            num_feat_layers=3,
            dropout=0.1,
            temperature=1.0,
            reduction_factor=1,
        ),
        alignment_frontend_config=UnitY2AlignmentFrontendConfig(
            unit_vocab_info=unit_vocab,
            text_vocab_size=10943,
        ),
    )
88
+
89
+
90
class UnitY2AlignmentBuilder:
    """Build a :class:`UnitY2AlignmentModel` from a :class:`UnitY2AlignmentConfig`.

    The builder creates the alignment frontend (text and unit embeddings plus
    their tokenizers) and the alignment encoder, and combines them into the
    final model.
    """

    config: UnitY2AlignmentConfig
    device: Optional[Device]
    dtype: DataType

    def __init__(
        self,
        config: UnitY2AlignmentConfig,
        *,
        device: Optional[Device] = None,
        dtype: DataType = torch.float32,
    ) -> None:
        """
        :param config:
            The configuration to use.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config

        self.device, self.dtype = device, dtype

    def build_model(self) -> UnitY2AlignmentModel:
        """Build the frontend and the encoder and wrap them in the model.

        The encoder is built with the default ``training=False``.
        """
        alignment_frontend = self.build_alignment_frontend()

        alignment_encoder = self.build_alignment_encoder()

        return UnitY2AlignmentModel(alignment_frontend, alignment_encoder)

    def build_alignment_frontend(self) -> UnitY2AlignmentFrontend:
        """Build the frontend: tokenizers plus text and unit embedding tables.

        Both tokenizers are loaded via ``config.model_name_or_card``.
        """
        frontend_cfg = self.config.alignment_frontend_config
        model_dim = self.config.alignment_encoder_config.model_dim

        text_tokenizer = load_unity_char_tokenizer(self.config.model_name_or_card)

        unit_tokenizer = load_unity_unit_tokenizer(self.config.model_name_or_card)

        # NOTE(review): the text embedding reuses the *unit* vocabulary's
        # pad_idx because the config carries no separate text padding index.
        # Confirm that both vocabularies share the same padding index.
        embed_text = StandardEmbedding(
            num_embeddings=frontend_cfg.text_vocab_size,
            embedding_dim=model_dim,
            pad_idx=frontend_cfg.unit_vocab_info.pad_idx,
            init_fn=init_scaled_embedding,
            device=self.device,
            dtype=self.dtype,
        )

        embed_unit = StandardEmbedding(
            num_embeddings=frontend_cfg.unit_vocab_info.size,
            embedding_dim=model_dim,
            pad_idx=frontend_cfg.unit_vocab_info.pad_idx,
            init_fn=init_scaled_embedding,
            device=self.device,
            dtype=self.dtype,
        )

        return UnitY2AlignmentFrontend(
            embed_text, embed_unit, text_tokenizer, unit_tokenizer
        )

    def build_alignment_encoder(self, training: bool = False) -> UnitY2AlignmentEncoder:
        """Build the alignment encoder.

        :param training:
            Whether the returned encoder is put in training mode. Defaults to
            ``False`` (evaluation mode).
        """
        cfg = self.config.alignment_encoder_config

        # NOTE(review): only ``dtype`` is forwarded here; ``self.device`` is
        # not passed to the encoder constructor.
        alignment_encoder = UnitY2AlignmentEncoder(
            embed_dim=cfg.model_dim,
            feat_dim=cfg.feat_dim,
            text_layers=cfg.num_text_layers,
            feat_layers=cfg.num_feat_layers,
            dropout=cfg.dropout,
            temperature=cfg.temperature,
            reduction_factor=cfg.reduction_factor,
            dtype=self.dtype,
        )

        # Use ``train()`` rather than assigning ``.training`` directly: the
        # attribute assignment only flips the flag on the top-level module,
        # while ``train()`` also propagates the mode to every submodule.
        alignment_encoder.train(training)

        return alignment_encoder
163
+
164
+
165
def create_unity2_alignment_model(
    config: UnitY2AlignmentConfig,
    device: Optional[Device] = None,
    dtype: DataType = torch.float32,
) -> UnitY2AlignmentModel:
    """Create a UnitY2 alignment model.

    Convenience wrapper that instantiates a :class:`UnitY2AlignmentBuilder`
    and returns the result of its ``build_model()``.

    :param config:
        The configuration to use.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    return UnitY2AlignmentBuilder(config, device=device, dtype=dtype).build_model()