jefson08 committed
Commit 2c42af0 · verified · 1 Parent(s): b81eef1

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.TGT filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "_name_or_path": "checkpoint-1200",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "architectures": [
6
+ "IndicTransForConditionalGeneration"
7
+ ],
8
+ "attention_dropout": 0.0,
9
+ "attn_implementation": null,
10
+ "auto_map": {
11
+ "AutoConfig": "configuration_indictrans.IndicTransConfig",
12
+ "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
13
+ },
14
+ "bos_token_id": 0,
15
+ "decoder_attention_heads": 16,
16
+ "decoder_embed_dim": 1024,
17
+ "decoder_ffn_dim": 8192,
18
+ "decoder_layerdrop": 0,
19
+ "decoder_layers": 18,
20
+ "decoder_normalize_before": true,
21
+ "decoder_start_token_id": 2,
22
+ "decoder_vocab_size": 126655,
23
+ "dropout": 0.2,
24
+ "encoder_attention_heads": 16,
25
+ "encoder_embed_dim": 1024,
26
+ "encoder_ffn_dim": 8192,
27
+ "encoder_layerdrop": 0,
28
+ "encoder_layers": 18,
29
+ "encoder_normalize_before": true,
30
+ "encoder_vocab_size": 32322,
31
+ "eos_token_id": 2,
32
+ "init_std": 0.02,
33
+ "is_encoder_decoder": true,
34
+ "layernorm_embedding": false,
35
+ "max_source_positions": 256,
36
+ "max_target_positions": 256,
37
+ "model_type": "IndicTrans",
38
+ "num_hidden_layers": 18,
39
+ "pad_token_id": 1,
40
+ "scale_embedding": true,
41
+ "share_decoder_input_output_embed": false,
42
+ "tokenizer_class": "IndicTransTokenizer",
43
+ "torch_dtype": "bfloat16",
44
+ "transformers_version": "4.44.2",
45
+ "use_cache": true
46
+ }
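A minimal loading sketch for this checkpoint, based on the `auto_map` and `torch_dtype` entries above. The repo id below is a placeholder, not this repository's actual name; `trust_remote_code=True` is needed because the architecture resolves to the custom configuration_indictrans.py / modeling_indictrans.py files uploaded in this commit.

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM

repo_id = "user/indictrans2-checkpoint"  # placeholder repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" in config.json
)
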
configuration_indictrans.py ADDED
@@ -0,0 +1,309 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2 model.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `input_ids` passed when calling [`IT2Model`].
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for the classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ """
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ attn_implementation="eager",
122
+ **kwargs,
123
+ ):
124
+ self.encoder_vocab_size = encoder_vocab_size
125
+ self.decoder_vocab_size = decoder_vocab_size
126
+ self.encoder_normalize_before = encoder_normalize_before
127
+ self.decoder_normalize_before = decoder_normalize_before
128
+ self.layernorm_embedding = layernorm_embedding
129
+ self.max_source_positions = max_source_positions
130
+ self.max_target_positions = max_target_positions
131
+ self.encoder_embed_dim = encoder_embed_dim
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.encoder_ffn_dim = encoder_ffn_dim
134
+ self.encoder_layers = encoder_layers
135
+ self.encoder_attention_heads = encoder_attention_heads
136
+ self.decoder_ffn_dim = decoder_ffn_dim
137
+ self.decoder_layers = decoder_layers
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.dropout = dropout
140
+ self.attention_dropout = attention_dropout
141
+ self.activation_dropout = activation_dropout
142
+ self.activation_function = activation_function
143
+ self.init_std = init_std
144
+ self.encoder_layerdrop = encoder_layerdrop
145
+ self.decoder_layerdrop = decoder_layerdrop
146
+ self.use_cache = use_cache
147
+ self.num_hidden_layers = encoder_layers
148
+ self.scale_embedding = scale_embedding
149
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
+ self.attn_implementation = attn_implementation
151
+
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ is_encoder_decoder=is_encoder_decoder,
157
+ decoder_start_token_id=decoder_start_token_id,
158
+ **kwargs,
159
+ )
160
+
161
+
162
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ common_inputs = OrderedDict(
166
+ [
167
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
168
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
169
+ ]
170
+ )
171
+
172
+ if self.use_past:
173
+ common_inputs["decoder_input_ids"] = {0: "batch"}
174
+ common_inputs["decoder_attention_mask"] = {
175
+ 0: "batch",
176
+ 1: "past_decoder_sequence + sequence",
177
+ }
178
+ else:
179
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
180
+ common_inputs["decoder_attention_mask"] = {
181
+ 0: "batch",
182
+ 1: "decoder_sequence",
183
+ }
184
+
185
+ if self.use_past:
186
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
187
+ return common_inputs
188
+
189
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
190
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
191
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
192
+ # was done for BART so that it can be updated if need be.
193
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
194
+ self,
195
+ tokenizer: PreTrainedTokenizer,
196
+ batch_size: int = -1,
197
+ seq_length: int = -1,
198
+ is_pair: bool = False,
199
+ framework: Optional[TensorType] = None,
200
+ ) -> Mapping[str, Any]:
201
+ # Copied from OnnxConfig.generate_dummy_inputs
202
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
203
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
204
+ batch_size = compute_effective_axis_dimension(
205
+ batch_size,
206
+ fixed_dimension=OnnxConfig.default_fixed_batch,
207
+ num_token_to_add=0,
208
+ )
209
+
210
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
211
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
212
+ seq_length = compute_effective_axis_dimension(
213
+ seq_length,
214
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
215
+ num_token_to_add=token_to_add,
216
+ )
217
+
218
+ # Generate dummy inputs according to compute batch and sequence
219
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
220
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
221
+ return common_inputs
222
+
223
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
224
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
225
+ self,
226
+ tokenizer: PreTrainedTokenizer,
227
+ batch_size: int = -1,
228
+ seq_length: int = -1,
229
+ is_pair: bool = False,
230
+ framework: Optional[TensorType] = None,
231
+ ) -> Mapping[str, Any]:
232
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
233
+ tokenizer, batch_size, seq_length, is_pair, framework
234
+ )
235
+
236
+ # Generate decoder inputs
237
+ decoder_seq_length = seq_length if not self.use_past else 1
238
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
239
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
240
+ )
241
+ decoder_inputs = {
242
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
243
+ }
244
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
245
+
246
+ if self.use_past:
247
+ if not is_torch_available():
248
+ raise ValueError(
249
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
250
+ )
251
+ else:
252
+ import torch
253
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
254
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
255
+ (
256
+ num_encoder_attention_heads,
257
+ num_decoder_attention_heads,
258
+ ) = self.num_attention_heads
259
+ encoder_shape = (
260
+ batch,
261
+ num_encoder_attention_heads,
262
+ encoder_seq_length,
263
+ self._config.hidden_size // num_encoder_attention_heads,
264
+ )
265
+ decoder_past_length = decoder_seq_length + 3
266
+ decoder_shape = (
267
+ batch,
268
+ num_decoder_attention_heads,
269
+ decoder_past_length,
270
+ self._config.hidden_size // num_decoder_attention_heads,
271
+ )
272
+
273
+ common_inputs["decoder_attention_mask"] = torch.cat(
274
+ [
275
+ common_inputs["decoder_attention_mask"],
276
+ torch.ones(batch, decoder_past_length),
277
+ ],
278
+ dim=1,
279
+ )
280
+
281
+ common_inputs["past_key_values"] = []
282
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
283
+ num_encoder_layers, num_decoder_layers = self.num_layers
284
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
285
+ max_num_layers = (
286
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
287
+ )
288
+ remaining_side_name = (
289
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
290
+ )
291
+
292
+ for _ in range(min_num_layers):
293
+ common_inputs["past_key_values"].append(
294
+ (
295
+ torch.zeros(decoder_shape),
296
+ torch.zeros(decoder_shape),
297
+ torch.zeros(encoder_shape),
298
+ torch.zeros(encoder_shape),
299
+ )
300
+ )
301
+ # TODO: test this.
302
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
303
+ for _ in range(min_num_layers, max_num_layers):
304
+ common_inputs["past_key_values"].append(
305
+ (torch.zeros(shape), torch.zeros(shape))
306
+ )
307
+ return common_inputs
308
+
309
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
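As a sketch (assuming configuration_indictrans.py above is importable from the working directory), the uploaded config.json corresponds to constructing `IndicTransConfig` with the following non-default values; anything not passed falls back to the defaults defined in `__init__`.

from configuration_indictrans import IndicTransConfig  # assumes the file above is on the import path

config = IndicTransConfig(
    encoder_vocab_size=32322,
    decoder_vocab_size=126655,
    encoder_embed_dim=1024,
    decoder_embed_dim=1024,
    encoder_ffn_dim=8192,
    decoder_ffn_dim=8192,
    encoder_layers=18,
    decoder_layers=18,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    max_source_positions=256,
    max_target_positions=256,
    encoder_normalize_before=True,
    decoder_normalize_before=True,
    activation_function="gelu",
    dropout=0.2,
)
assert config.num_hidden_layers == 18  # derived from encoder_layers in __init__
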
dict.SRC.json ADDED
The diff for this file is too large to render. See raw diff
 
dict.TGT.json ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "bos_token_id": 0,
3
+ "decoder_start_token_id": 2,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 1,
6
+ "transformers_version": "4.44.2"
7
+ }
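These generation defaults mirror the special-token ids in config.json. A purely illustrative sketch of the same values expressed with transformers' `GenerationConfig` (loading the repository would pick them up automatically):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    bos_token_id=0,
    decoder_start_token_id=2,  # decoding starts from the eos/decoder-start id
    eos_token_id=2,
    pad_token_id=1,
)
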
model.SRC ADDED
Binary file (759 kB).
 
model.TGT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
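Assuming model.SRC and model.TGT are the source- and target-side SentencePiece models used by the `IndicTransTokenizer` named in config.json (with dict.SRC.json / dict.TGT.json holding the matching vocabularies), they could be inspected as below; this is an assumption about the file format, not something stated in this diff.

import sentencepiece as spm

# Assumed format: a SentencePiece model for the target side.
sp_tgt = spm.SentencePieceProcessor(model_file="model.TGT")
print(sp_tgt.get_piece_size())
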
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea67a27d22d8aed4603f79437cd9a36cdf096acf2c62df9dd9d20d7343546e27
3
+ size 2247492800
modeling_indictrans.py ADDED
@@ -0,0 +1,1801 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _prepare_4d_attention_mask,
29
+ _prepare_4d_attention_mask_for_sdpa,
30
+ _prepare_4d_causal_attention_mask,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+
34
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ Seq2SeqLMOutput,
39
+ Seq2SeqModelOutput
40
+ )
41
+
42
+ from transformers.utils import (
43
+ logging,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ )
47
+
48
+ from transformers.modeling_utils import PreTrainedModel
49
+
50
+ from .configuration_indictrans import IndicTransConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
+
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except ImportError:
62
+ pass
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
+
78
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
79
+ def shift_tokens_right(
80
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
81
+ ):
82
+ """
83
+ Shift input ids one token to the right.
84
+ """
85
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
86
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
87
+ shifted_input_ids[:, 0] = decoder_start_token_id
88
+
89
+ if pad_token_id is None:
90
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
91
+ # replace possible -100 values in labels by `pad_token_id`
92
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
93
+
94
+ return shifted_input_ids
95
+
96
+
97
+ def create_position_ids_from_input_ids(
98
+ input_ids, padding_idx, past_key_values_length=0
99
+ ):
100
+ """
101
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
102
+ are ignored. This is modified from fairseq's `utils.make_positions`.
103
+ """
104
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
105
+ mask = input_ids.ne(padding_idx).int()
106
+ incremental_indices = (
107
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
108
+ ) * mask
109
+ return incremental_indices.long() + padding_idx
110
+
111
+
112
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
113
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
114
+ """This module produces sinusoidal positional embeddings of any length."""
115
+
116
+ def __init__(
117
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
118
+ ):
119
+ super().__init__()
120
+ self.offset = 2
121
+ self.embedding_dim = embedding_dim
122
+ self.padding_idx = padding_idx
123
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
124
+
125
+ def make_weights(
126
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
127
+ ):
128
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
129
+ if hasattr(self, "weights"):
130
+ # in forward put the weights on the correct dtype and device of the param
131
+ emb_weights = emb_weights.to(
132
+ dtype=self.weights.dtype, device=self.weights.device
133
+ )
134
+
135
+ self.register_buffer("weights", emb_weights, persistent=False)
136
+
137
+ @staticmethod
138
+ def get_embedding(
139
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
140
+ ):
141
+ """
142
+ Build sinusoidal embeddings.
143
+
144
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
145
+ "Attention Is All You Need".
146
+ """
147
+ half_dim = embedding_dim // 2
148
+ emb = math.log(10000) / (half_dim - 1)
149
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
150
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
151
+ 1
152
+ ) * emb.unsqueeze(0)
153
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
154
+ num_embeddings, -1
155
+ )
156
+ if embedding_dim % 2 == 1:
157
+ # zero pad
158
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
159
+ if padding_idx is not None:
160
+ emb[padding_idx, :] = 0
161
+
162
+ return emb.to(torch.get_default_dtype())
163
+
164
+ @torch.no_grad()
165
+ def forward(
166
+ self,
167
+ input_ids: torch.Tensor = None,
168
+ inputs_embeds: torch.Tensor = None,
169
+ past_key_values_length: int = 0,
170
+ ):
171
+ if input_ids is not None:
172
+ bsz, seq_len = input_ids.size()
173
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
174
+ position_ids = create_position_ids_from_input_ids(
175
+ input_ids, self.padding_idx, past_key_values_length
176
+ ).to(input_ids.device)
177
+ else:
178
+ bsz, seq_len = inputs_embeds.size()[:-1]
179
+ position_ids = self.create_position_ids_from_inputs_embeds(
180
+ inputs_embeds, past_key_values_length
181
+ )
182
+
183
+ # expand embeddings if needed
184
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
185
+ if max_pos > self.weights.size(0):
186
+ self.make_weights(
187
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
188
+ )
189
+
190
+ return (
191
+ self.weights.index_select(0, position_ids.view(-1))
192
+ .view(bsz, seq_len, self.weights.shape[-1])
193
+ .detach()
194
+ )
195
+
196
+ def create_position_ids_from_inputs_embeds(
197
+ self, inputs_embeds, past_key_values_length
198
+ ):
199
+ """
200
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
201
+
202
+ Args:
203
+ inputs_embeds: torch.Tensor
204
+
205
+ Returns: torch.Tensor
206
+ """
207
+ input_shape = inputs_embeds.size()[:-1]
208
+ sequence_length = input_shape[1]
209
+
210
+ position_ids = torch.arange(
211
+ self.padding_idx + 1,
212
+ sequence_length + self.padding_idx + 1,
213
+ dtype=torch.long,
214
+ device=inputs_embeds.device,
215
+ )
216
+ return (
217
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
218
+ + past_key_values_length
219
+ )
220
+
221
+
222
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
223
+ class IndicTransAttention(nn.Module):
224
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
225
+
226
+ def __init__(
227
+ self,
228
+ embed_dim: int,
229
+ num_heads: int,
230
+ dropout: float = 0.0,
231
+ is_decoder: bool = False,
232
+ bias: bool = True,
233
+ is_causal: bool = False,
234
+ config: Optional[IndicTransConfig] = None,
235
+ ):
236
+ super().__init__()
237
+ self.embed_dim = embed_dim
238
+ self.num_heads = num_heads
239
+ self.dropout = dropout
240
+ self.head_dim = embed_dim // num_heads
241
+ self.config = config
242
+
243
+ if (self.head_dim * num_heads) != self.embed_dim:
244
+ raise ValueError(
245
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
246
+ f" and `num_heads`: {num_heads})."
247
+ )
248
+ self.scaling = self.head_dim**-0.5
249
+ self.is_decoder = is_decoder
250
+ self.is_causal = is_causal
251
+
252
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
253
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+
257
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
258
+ return (
259
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
260
+ .transpose(1, 2)
261
+ .contiguous()
262
+ )
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ key_value_states: Optional[torch.Tensor] = None,
268
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
269
+ attention_mask: Optional[torch.Tensor] = None,
270
+ layer_head_mask: Optional[torch.Tensor] = None,
271
+ output_attentions: bool = False,
272
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
273
+ """Input shape: Batch x Time x Channel"""
274
+
275
+ # if key_value_states are provided this layer is used as a cross-attention layer
276
+ # for the decoder
277
+ is_cross_attention = key_value_states is not None
278
+
279
+ bsz, tgt_len, _ = hidden_states.size()
280
+
281
+ # get query proj
282
+ query_states = self.q_proj(hidden_states) * self.scaling
283
+ # get key, value proj
284
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
285
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
286
+ # the provided `key_value_states` to support prefix tuning
287
+ if (
288
+ is_cross_attention
289
+ and past_key_value is not None
290
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
291
+ ):
292
+ # reuse k,v, cross_attentions
293
+ key_states = past_key_value[0]
294
+ value_states = past_key_value[1]
295
+ elif is_cross_attention:
296
+ # cross_attentions
297
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
298
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
299
+ elif past_key_value is not None:
300
+ # reuse k, v, self_attention
301
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
302
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
303
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
304
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
305
+ else:
306
+ # self_attention
307
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
308
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
309
+
310
+ if self.is_decoder:
311
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
312
+ # Further calls to cross_attention layer can then reuse all cross-attention
313
+ # key/value_states (first "if" case)
314
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
315
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
316
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
317
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
318
+ past_key_value = (key_states, value_states)
319
+
320
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
321
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
322
+ key_states = key_states.reshape(*proj_shape)
323
+ value_states = value_states.reshape(*proj_shape)
324
+
325
+ src_len = key_states.size(1)
326
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
327
+
328
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
329
+ raise ValueError(
330
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
331
+ f" {attn_weights.size()}"
332
+ )
333
+
334
+ if attention_mask is not None:
335
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
336
+ raise ValueError(
337
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
338
+ )
339
+ attn_weights = (
340
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
341
+ + attention_mask
342
+ )
343
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
344
+
345
+ attn_weights = F.softmax(attn_weights, dim=-1)
346
+
347
+ if layer_head_mask is not None:
348
+ if layer_head_mask.size() != (self.num_heads,):
349
+ raise ValueError(
350
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
351
+ f" {layer_head_mask.size()}"
352
+ )
353
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
354
+ bsz, self.num_heads, tgt_len, src_len
355
+ )
356
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
357
+
358
+ if output_attentions:
359
+ # this operation is a bit awkward, but it's required to
360
+ # make sure that attn_weights keeps its gradient.
361
+ # In order to do so, attn_weights have to be reshaped
362
+ # twice and have to be reused in the following
363
+ attn_weights_reshaped = attn_weights.view(
364
+ bsz, self.num_heads, tgt_len, src_len
365
+ )
366
+ attn_weights = attn_weights_reshaped.view(
367
+ bsz * self.num_heads, tgt_len, src_len
368
+ )
369
+ else:
370
+ attn_weights_reshaped = None
371
+
372
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
373
+
374
+ attn_output = torch.bmm(attn_probs, value_states)
375
+
376
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
377
+ raise ValueError(
378
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
379
+ f" {attn_output.size()}"
380
+ )
381
+
382
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
383
+ attn_output = attn_output.transpose(1, 2)
384
+
385
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
386
+ # partitioned across GPUs when using tensor-parallelism.
387
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
388
+
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights_reshaped, past_key_value
392
+
393
+
394
+ class IndicTransFlashAttention2(IndicTransAttention):
395
+ """
396
+ IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stay
397
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
398
+ flash attention and deal with padding tokens in case the input contains any of them.
399
+ """
400
+
401
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
402
+ def __init__(self, *args, **kwargs):
403
+ super().__init__(*args, **kwargs)
404
+
405
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
406
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
407
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
408
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
409
+
410
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
411
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ key_value_states: Optional[torch.Tensor] = None,
417
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
418
+ attention_mask: Optional[torch.Tensor] = None,
419
+ layer_head_mask: Optional[torch.Tensor] = None,
420
+ output_attentions: bool = False,
421
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
422
+ # IndicTransFlashAttention2 attention does not support output_attentions
423
+ if output_attentions:
424
+ raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
425
+
426
+ # if key_value_states are provided this layer is used as a cross-attention layer
427
+ # for the decoder
428
+ is_cross_attention = key_value_states is not None
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ # get query proj
433
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
434
+ # get key, value proj
435
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
436
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
437
+ # the provided `key_value_states` to support prefix tuning
438
+ if (
439
+ is_cross_attention
440
+ and past_key_value is not None
441
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
442
+ ):
443
+ # reuse k,v, cross_attentions
444
+ key_states = past_key_value[0].transpose(1, 2)
445
+ value_states = past_key_value[1].transpose(1, 2)
446
+ elif is_cross_attention:
447
+ # cross_attentions
448
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
449
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
450
+ elif past_key_value is not None:
451
+ # reuse k, v, self_attention
452
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
453
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
454
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
455
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
456
+ else:
457
+ # self_attention
458
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
459
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
460
+
461
+ if self.is_decoder:
462
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
463
+ # Further calls to cross_attention layer can then reuse all cross-attention
464
+ # key/value_states (first "if" case)
465
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
466
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
467
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
468
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
469
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
470
+
471
+ kv_seq_len = key_states.shape[-2]
472
+ if past_key_value is not None:
473
+ kv_seq_len += past_key_value[0].shape[-2]
474
+
475
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
476
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
477
+ # cast them back in the correct dtype just to be sure everything works as expected.
478
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
479
+ # in fp32. (LlamaRMSNorm handles it correctly)
480
+
481
+ input_dtype = query_states.dtype
482
+ if input_dtype == torch.float32:
483
+ if torch.is_autocast_enabled():
484
+ target_dtype = torch.get_autocast_gpu_dtype()
485
+ # Handle the case where the model is quantized
486
+ elif hasattr(self.config, "_pre_quantization_dtype"):
487
+ target_dtype = self.config._pre_quantization_dtype
488
+ else:
489
+ target_dtype = self.q_proj.weight.dtype
490
+
491
+ logger.warning_once(
492
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
493
+ f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast the input back to"
494
+ f" {target_dtype}."
495
+ )
496
+
497
+ query_states = query_states.to(target_dtype)
498
+ key_states = key_states.to(target_dtype)
499
+ value_states = value_states.to(target_dtype)
500
+
501
+ attn_output = self._flash_attention_forward(
502
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
503
+ )
504
+
505
+ attn_output = attn_output.reshape(bsz, q_len, -1)
506
+ attn_output = self.out_proj(attn_output)
507
+
508
+ if not output_attentions:
509
+ attn_weights = None
510
+
511
+ return attn_output, attn_weights, past_key_value
512
+
513
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
514
+ def _flash_attention_forward(
515
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
516
+ ):
517
+ """
518
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
519
+ first unpad the input, then computes the attention scores and pad the final attention scores.
520
+
521
+ Args:
522
+ query_states (`torch.Tensor`):
523
+ Input query states to be passed to Flash Attention API
524
+ key_states (`torch.Tensor`):
525
+ Input key states to be passed to Flash Attention API
526
+ value_states (`torch.Tensor`):
527
+ Input value states to be passed to Flash Attention API
528
+ attention_mask (`torch.Tensor`):
529
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
530
+ position of padding tokens and 1 for the position of non-padding tokens.
531
+ dropout (`float`):
532
+ Attention dropout
533
+ softmax_scale (`float`, *optional*):
534
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
535
+ """
536
+ if not self._flash_attn_uses_top_left_mask:
537
+ causal = self.is_causal
538
+ else:
539
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
540
+ causal = self.is_causal and query_length != 1
541
+
542
+ # Contains at least one padding token in the sequence
543
+ if attention_mask is not None:
544
+ batch_size = query_states.shape[0]
545
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
546
+ query_states, key_states, value_states, attention_mask, query_length
547
+ )
548
+
549
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
550
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
551
+
552
+ attn_output_unpad = flash_attn_varlen_func(
553
+ query_states,
554
+ key_states,
555
+ value_states,
556
+ cu_seqlens_q=cu_seqlens_q,
557
+ cu_seqlens_k=cu_seqlens_k,
558
+ max_seqlen_q=max_seqlen_in_batch_q,
559
+ max_seqlen_k=max_seqlen_in_batch_k,
560
+ dropout_p=dropout,
561
+ softmax_scale=softmax_scale,
562
+ causal=causal,
563
+ )
564
+
565
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
566
+ else:
567
+ attn_output = flash_attn_func(
568
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
569
+ )
570
+
571
+ return attn_output
572
+
573
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
574
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
575
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
576
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
577
+
578
+ key_layer = index_first_axis(
579
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
580
+ )
581
+ value_layer = index_first_axis(
582
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
583
+ )
584
+ if query_length == kv_seq_len:
585
+ query_layer = index_first_axis(
586
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
587
+ )
588
+ cu_seqlens_q = cu_seqlens_k
589
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
590
+ indices_q = indices_k
591
+ elif query_length == 1:
592
+ max_seqlen_in_batch_q = 1
593
+ cu_seqlens_q = torch.arange(
594
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
595
+ ) # There is a memcpy here, that is very bad.
596
+ indices_q = cu_seqlens_q[:-1]
597
+ query_layer = query_layer.squeeze(1)
598
+ else:
599
+ # The -q_len: slice assumes left padding.
600
+ attention_mask = attention_mask[:, -query_length:]
601
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
602
+
603
+ return (
604
+ query_layer,
605
+ key_layer,
606
+ value_layer,
607
+ indices_q,
608
+ (cu_seqlens_q, cu_seqlens_k),
609
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
610
+ )
611
+
612
+
613
+ class IndicTransSdpaAttention(IndicTransAttention):
614
+ def forward(
615
+ self,
616
+ hidden_states: torch.Tensor,
617
+ key_value_states: Optional[torch.Tensor] = None,
618
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ layer_head_mask: Optional[torch.Tensor] = None,
621
+ output_attentions: bool = False,
622
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
623
+ """Input shape: Batch x Time x Channel"""
624
+ if output_attentions or layer_head_mask is not None:
625
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
626
+ logger.warning_once(
627
+ "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
628
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
629
+ )
630
+ return super().forward(
631
+ hidden_states,
632
+ key_value_states=key_value_states,
633
+ past_key_value=past_key_value,
634
+ attention_mask=attention_mask,
635
+ layer_head_mask=layer_head_mask,
636
+ output_attentions=output_attentions,
637
+ )
638
+
639
+ # if key_value_states are provided this layer is used as a cross-attention layer
640
+ # for the decoder
641
+ is_cross_attention = key_value_states is not None
642
+
643
+ bsz, tgt_len, _ = hidden_states.size()
644
+
645
+ # get query proj
646
+ query_states = self.q_proj(hidden_states)
647
+ # get key, value proj
648
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
649
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
650
+ # the provided `key_value_states` to support prefix tuning
651
+ if (
652
+ is_cross_attention
653
+ and past_key_value is not None
654
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
655
+ ):
656
+ # reuse k,v, cross_attentions
657
+ key_states = past_key_value[0]
658
+ value_states = past_key_value[1]
659
+ elif is_cross_attention:
660
+ # cross_attentions
661
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
662
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
663
+ elif past_key_value is not None:
664
+ # reuse k, v, self_attention
665
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
666
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
667
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
668
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
669
+ else:
670
+ # self_attention
671
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
672
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
673
+
674
+ if self.is_decoder:
675
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
676
+ # Further calls to cross_attention layer can then reuse all cross-attention
677
+ # key/value_states (first "if" case)
678
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
679
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
680
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
681
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
682
+ past_key_value = (key_states, value_states)
683
+
684
+ query_states = self._shape(query_states, tgt_len, bsz)
685
+
686
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
687
+ # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
688
+ attn_output = F.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=attention_mask,
693
+ dropout_p=self.dropout if self.training else 0.0,
694
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
695
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
696
+ )
697
+
698
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
699
+ raise ValueError(
700
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
701
+ f" {attn_output.size()}"
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2)
705
+
706
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
707
+ # partitioned across GPUs when using tensor-parallelism.
708
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
709
+
710
+ attn_output = self.out_proj(attn_output)
711
+
712
+ return attn_output, None, past_key_value
713
+
714
+
715
+ INDICTRANS_ATTENTION_CLASSES = {
716
+ "eager": IndicTransAttention,
717
+ "sdpa": IndicTransSdpaAttention,
718
+ "flash_attention_2": IndicTransFlashAttention2,
719
+ }
720
+
721
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
722
+ class IndicTransEncoderLayer(nn.Module):
723
+ def __init__(self, config: IndicTransConfig):
724
+ super().__init__()
725
+ self.embed_dim = config.encoder_embed_dim
726
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
727
+ embed_dim=self.embed_dim,
728
+ num_heads=config.encoder_attention_heads,
729
+ dropout=config.attention_dropout,
730
+ config=config,
731
+ )
732
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
733
+ self.dropout = config.dropout
734
+ self.activation_fn = ACT2FN[config.activation_function]
735
+ self.activation_dropout = config.activation_dropout
736
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
737
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
738
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
739
+ self.normalize_before = config.encoder_normalize_before
740
+
741
+ def forward(
742
+ self,
743
+ hidden_states: torch.Tensor,
744
+ attention_mask: torch.Tensor,
745
+ layer_head_mask: torch.Tensor,
746
+ output_attentions: bool = False,
747
+ ) -> torch.Tensor:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
754
+ `(encoder_attention_heads,)`.
755
+ output_attentions (`bool`, *optional*):
756
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
757
+ returned tensors for more detail.
758
+ """
759
+ residual = hidden_states
760
+ if self.normalize_before:
761
+ hidden_states = self.self_attn_layer_norm(hidden_states)
762
+ hidden_states, attn_weights, _ = self.self_attn(
763
+ hidden_states=hidden_states,
764
+ attention_mask=attention_mask,
765
+ layer_head_mask=layer_head_mask,
766
+ output_attentions=output_attentions,
767
+ )
768
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
769
+ hidden_states = residual + hidden_states
770
+ if not self.normalize_before:
771
+ hidden_states = self.self_attn_layer_norm(hidden_states)
772
+
773
+ residual = hidden_states
774
+ if self.normalize_before:
775
+ hidden_states = self.final_layer_norm(hidden_states)
776
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
777
+ hidden_states = F.dropout(
778
+ hidden_states, p=self.activation_dropout, training=self.training
779
+ )
780
+ hidden_states = self.fc2(hidden_states)
781
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
782
+ hidden_states = residual + hidden_states
783
+ if not self.normalize_before:
784
+ hidden_states = self.final_layer_norm(hidden_states)
785
+
786
+ if hidden_states.dtype == torch.float16 and (
787
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
788
+ ):
789
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
790
+ hidden_states = torch.clamp(
791
+ hidden_states, min=-clamp_value, max=clamp_value
792
+ )
793
+
794
+ outputs = (hidden_states,)
795
+
796
+ if output_attentions:
797
+ outputs += (attn_weights,)
798
+
799
+ return outputs
800
+
801
+
802
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
803
+ class IndicTransDecoderLayer(nn.Module):
804
+ def __init__(self, config: IndicTransConfig):
805
+ super().__init__()
806
+ self.embed_dim = config.decoder_embed_dim
807
+
808
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
809
+ embed_dim=self.embed_dim,
810
+ num_heads=config.decoder_attention_heads,
811
+ dropout=config.attention_dropout,
812
+ is_decoder=True,
813
+ is_causal=True,
814
+ config=config,
815
+ )
816
+ self.dropout = config.dropout
817
+ self.activation_fn = ACT2FN[config.activation_function]
818
+ self.activation_dropout = config.activation_dropout
819
+
820
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
821
+ self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
822
+ self.embed_dim,
823
+ config.decoder_attention_heads,
824
+ dropout=config.attention_dropout,
825
+ is_decoder=True,
826
+ config=config,
827
+ )
828
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
829
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
830
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
831
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
832
+ self.normalize_before = config.decoder_normalize_before
833
+
834
+ def forward(
835
+ self,
836
+ hidden_states: torch.Tensor,
837
+ attention_mask: Optional[torch.Tensor] = None,
838
+ encoder_hidden_states: Optional[torch.Tensor] = None,
839
+ encoder_attention_mask: Optional[torch.Tensor] = None,
840
+ layer_head_mask: Optional[torch.Tensor] = None,
841
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
842
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
843
+ output_attentions: Optional[bool] = False,
844
+ use_cache: Optional[bool] = True,
845
+ ) -> torch.Tensor:
846
+ """
847
+ Args:
848
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
849
+ attention_mask (`torch.FloatTensor`): attention mask of size
850
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
851
+ encoder_hidden_states (`torch.FloatTensor`):
852
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
853
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
854
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
855
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
856
+ `(decoder_attention_heads,)`.
857
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
858
+ size `(decoder_attention_heads,)`.
859
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
860
+ output_attentions (`bool`, *optional*):
861
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
862
+ returned tensors for more detail.
863
+ """
864
+ residual = hidden_states
865
+ if self.normalize_before:
866
+ hidden_states = self.self_attn_layer_norm(hidden_states)
867
+
868
+ # Self Attention
869
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
870
+ self_attn_past_key_value = (
871
+ past_key_value[:2] if past_key_value is not None else None
872
+ )
873
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
874
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
875
+ hidden_states=hidden_states,
876
+ past_key_value=self_attn_past_key_value,
877
+ attention_mask=attention_mask,
878
+ layer_head_mask=layer_head_mask,
879
+ output_attentions=output_attentions,
880
+ )
881
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
882
+ hidden_states = residual + hidden_states
883
+ if not self.normalize_before:
884
+ hidden_states = self.self_attn_layer_norm(hidden_states)
885
+
886
+ # Cross-Attention Block
887
+ cross_attn_present_key_value = None
888
+ cross_attn_weights = None
889
+ if encoder_hidden_states is not None:
890
+ residual = hidden_states
891
+ if self.normalize_before:
892
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
893
+
894
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
895
+ cross_attn_past_key_value = (
896
+ past_key_value[-2:] if past_key_value is not None else None
897
+ )
898
+ (
899
+ hidden_states,
900
+ cross_attn_weights,
901
+ cross_attn_present_key_value,
902
+ ) = self.encoder_attn(
903
+ hidden_states=hidden_states,
904
+ key_value_states=encoder_hidden_states,
905
+ attention_mask=encoder_attention_mask,
906
+ layer_head_mask=cross_attn_layer_head_mask,
907
+ past_key_value=cross_attn_past_key_value,
908
+ output_attentions=output_attentions,
909
+ )
910
+ hidden_states = F.dropout(
911
+ hidden_states, p=self.dropout, training=self.training
912
+ )
913
+ hidden_states = residual + hidden_states
914
+ if not self.normalize_before:
915
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
916
+
917
+ # add cross-attn to positions 3,4 of present_key_value tuple
918
+ present_key_value = present_key_value + cross_attn_present_key_value
919
+
920
+ # Fully Connected
921
+ residual = hidden_states
922
+ if self.normalize_before:
923
+ hidden_states = self.final_layer_norm(hidden_states)
924
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
925
+ hidden_states = F.dropout(
926
+ hidden_states, p=self.activation_dropout, training=self.training
927
+ )
928
+ hidden_states = self.fc2(hidden_states)
929
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
930
+ hidden_states = residual + hidden_states
931
+ if not self.normalize_before:
932
+ hidden_states = self.final_layer_norm(hidden_states)
933
+
934
+ outputs = (hidden_states,)
935
+
936
+ if output_attentions:
937
+ outputs += (self_attn_weights, cross_attn_weights)
938
+
939
+ if use_cache:
940
+ outputs += (present_key_value,)
941
+
942
+ return outputs
943
+
944
+
945
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PreTrainedModel->IndicTrans
946
+ class IndicTransPreTrainedModel(PreTrainedModel):
947
+ config_class = IndicTransConfig
948
+ base_model_prefix = "model"
949
+ supports_gradient_checkpointing = True
950
+ _no_split_modules = ["IndicTransAttention"]
951
+
952
+ def _init_weights(self, module):
953
+ std = self.config.init_std
954
+ if isinstance(module, nn.Linear):
955
+ module.weight.data.normal_(mean=0.0, std=std)
956
+ if module.bias is not None:
957
+ module.bias.data.zero_()
958
+ elif isinstance(module, nn.Embedding):
959
+ module.weight.data.normal_(mean=0.0, std=std)
960
+ if module.padding_idx is not None:
961
+ module.weight.data[module.padding_idx].zero_()
962
+
963
+ def _set_gradient_checkpointing(self, module, value=False):
964
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
965
+ module.gradient_checkpointing = value
966
+
967
+
968
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Encoder->IndicTrans
969
+ class IndicTransEncoder(IndicTransPreTrainedModel):
970
+ """
971
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
972
+ [`IndicTransEncoderLayer`].
973
+
974
+ Args:
975
+ config: IndicTransConfig
976
+ embed_tokens (nn.Embedding): output embedding
977
+ """
978
+
979
+ def __init__(
980
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
981
+ ):
982
+ super().__init__(config)
983
+
984
+ self.dropout = config.dropout
985
+ self.layerdrop = config.encoder_layerdrop
986
+
987
+ embed_dim = config.encoder_embed_dim
988
+ self.padding_idx = config.pad_token_id
989
+ self.max_source_positions = config.max_source_positions
990
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
991
+
992
+ self.embed_tokens = nn.Embedding(
993
+ config.encoder_vocab_size, embed_dim, self.padding_idx
994
+ )
995
+
996
+ if embed_tokens is not None:
997
+ self.embed_tokens.weight = embed_tokens.weight
998
+
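+ # Fixed sinusoidal position embeddings (not learned), offset by the padding index.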
999
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1000
+ config.max_source_positions,
1001
+ embed_dim,
1002
+ self.padding_idx,
1003
+ )
1004
+ self.layers = nn.ModuleList(
1005
+ [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
1006
+ )
1007
+ self.layer_norm = (
1008
+ nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
1009
+ )
1010
+ self.layernorm_embedding = (
1011
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1012
+ )
1013
+
1014
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1015
+ self._use_sdpa = config._attn_implementation == "sdpa"
1016
+
1017
+ self.gradient_checkpointing = False
1018
+ # Initialize weights and apply final processing
1019
+ self.post_init()
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids: Optional[torch.Tensor] = None,
1024
+ attention_mask: Optional[torch.Tensor] = None,
1025
+ head_mask: Optional[torch.Tensor] = None,
1026
+ inputs_embeds: Optional[torch.Tensor] = None,
1027
+ output_attentions: Optional[bool] = None,
1028
+ output_hidden_states: Optional[bool] = None,
1029
+ return_dict: Optional[bool] = None,
1030
+ ):
1031
+ r"""
1032
+ Args:
1033
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1034
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1035
+ provide it.
1036
+
1037
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1038
+ [`PreTrainedTokenizer.__call__`] for details.
1039
+
1040
+ [What are input IDs?](../glossary#input-ids)
1041
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1042
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1043
+
1044
+ - 1 for tokens that are **not masked**,
1045
+ - 0 for tokens that are **masked**.
1046
+
1047
+ [What are attention masks?](../glossary#attention-mask)
1048
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
1049
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1050
+
1051
+ - 1 indicates the head is **not masked**,
1052
+ - 0 indicates the head is **masked**.
1053
+
1054
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1055
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1056
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1057
+ than the model's internal embedding lookup matrix.
1058
+ output_attentions (`bool`, *optional*):
1059
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1060
+ returned tensors for more detail.
1061
+ output_hidden_states (`bool`, *optional*):
1062
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1063
+ for more detail.
1064
+ return_dict (`bool`, *optional*):
1065
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1066
+ """
1067
+ output_attentions = (
1068
+ output_attentions
1069
+ if output_attentions is not None
1070
+ else self.config.output_attentions
1071
+ )
1072
+ output_hidden_states = (
1073
+ output_hidden_states
1074
+ if output_hidden_states is not None
1075
+ else self.config.output_hidden_states
1076
+ )
1077
+ return_dict = (
1078
+ return_dict if return_dict is not None else self.config.use_return_dict
1079
+ )
1080
+
1081
+ # retrieve input_ids and inputs_embeds
1082
+ if input_ids is not None and inputs_embeds is not None:
1083
+ raise ValueError(
1084
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1085
+ )
1086
+ elif input_ids is not None:
1087
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1088
+ input_shape = input_ids.size()
1089
+ input_ids = input_ids.view(-1, input_shape[-1])
1090
+ elif inputs_embeds is not None:
1091
+ input_shape = inputs_embeds.size()[:-1]
1092
+ else:
1093
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1094
+
1095
+ if inputs_embeds is None:
1096
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1097
+
1098
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
1099
+ embed_pos = embed_pos.to(inputs_embeds.device)
1100
+
1101
+ hidden_states = inputs_embeds + embed_pos
1102
+ if self.layernorm_embedding is not None:
1103
+ hidden_states = self.layernorm_embedding(hidden_states)
1104
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1105
+
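+ # Prepare the padding mask for the selected attention backend: kept 2D (or dropped) for flash-attention-2, expanded to a 4D mask otherwise.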
1106
+ if attention_mask is not None:
1107
+ if self._use_flash_attention_2:
1108
+ attention_mask = attention_mask if 0 in attention_mask else None
1109
+ elif self._use_sdpa and head_mask is None and not output_attentions:
1110
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1111
+ # the manual implementation that requires a 4D causal mask in all cases.
1112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1113
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
1114
+ else:
1115
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1116
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1117
+
1118
+
1119
+ encoder_states = () if output_hidden_states else None
1120
+ all_attentions = () if output_attentions else None
1121
+
1122
+ # check if head_mask has a correct number of layers specified if desired
1123
+ if head_mask is not None:
1124
+ if head_mask.size()[0] != len(self.layers):
1125
+ raise ValueError(
1126
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1127
+ f" {head_mask.size()[0]}."
1128
+ )
1129
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1130
+
1131
+ for idx, encoder_layer in enumerate(self.layers):
1132
+ if output_hidden_states:
1133
+ encoder_states = encoder_states + (hidden_states,)
1134
+
1135
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1136
+ dropout_probability = torch.rand([])
1137
+
1138
+ skip_the_layer = (
1139
+ True
1140
+ if self.training and (dropout_probability < self.layerdrop)
1141
+ else False
1142
+ )
1143
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1144
+ # under deepspeed zero3 all gpus must run in sync
1145
+
1146
+ if self.gradient_checkpointing and self.training:
1147
+ # create gradient checkpointing function
1148
+ def create_custom_forward(module):
1149
+ def custom_forward(*inputs):
1150
+ return module(*inputs, output_attentions)
1151
+
1152
+ return custom_forward
1153
+
1154
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1155
+ create_custom_forward(encoder_layer),
1156
+ hidden_states,
1157
+ attention_mask,
1158
+ (head_mask[idx] if head_mask is not None else None),
1159
+ )
1160
+ else:
1161
+ layer_outputs = encoder_layer(
1162
+ hidden_states,
1163
+ attention_mask,
1164
+ layer_head_mask=(
1165
+ head_mask[idx] if head_mask is not None else None
1166
+ ),
1167
+ output_attentions=output_attentions,
1168
+ )
1169
+
1170
+ hidden_states = layer_outputs[0]
1171
+
1172
+ if skip_the_layer:
1173
+ layer_outputs = (None, None)
1174
+
1175
+ if output_attentions:
1176
+ all_attentions = all_attentions + (layer_outputs[1],)
1177
+
1178
+ if self.layer_norm is not None:
1179
+ hidden_states = self.layer_norm(hidden_states)
1180
+
1181
+ if output_hidden_states:
1182
+ encoder_states = encoder_states + (hidden_states,)
1183
+
1184
+ if not return_dict:
1185
+ return tuple(
1186
+ v
1187
+ for v in [hidden_states, encoder_states, all_attentions]
1188
+ if v is not None
1189
+ )
1190
+ return BaseModelOutput(
1191
+ last_hidden_state=hidden_states,
1192
+ hidden_states=encoder_states,
1193
+ attentions=all_attentions,
1194
+ )
1195
+
1196
+
1197
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Decoder->IndicTrans
1198
+ class IndicTransDecoder(IndicTransPreTrainedModel):
1199
+ """
1200
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
1201
+
1202
+ Args:
1203
+ config: IndicTransConfig
1204
+ embed_tokens (nn.Embedding): output embedding
1205
+ """
1206
+
1207
+ def __init__(
1208
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
1209
+ ):
1210
+ super().__init__(config)
1211
+ self.dropout = config.dropout
1212
+ self.layerdrop = config.decoder_layerdrop
1213
+
1214
+ embed_dim = config.decoder_embed_dim
1215
+ self.padding_idx = config.pad_token_id
1216
+ self.max_target_positions = config.max_target_positions
1217
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
1218
+
1219
+ self.embed_tokens = nn.Embedding(
1220
+ config.decoder_vocab_size, embed_dim, self.padding_idx
1221
+ )
1222
+
1223
+ if embed_tokens is not None:
1224
+ self.embed_tokens.weight = embed_tokens.weight
1225
+
1226
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1227
+ config.max_target_positions,
1228
+ embed_dim,
1229
+ self.padding_idx,
1230
+ )
1231
+ self.layers = nn.ModuleList(
1232
+ [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
1233
+ )
1234
+ self.layer_norm = (
1235
+ nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
1236
+ )
1237
+ self.layernorm_embedding = (
1238
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1239
+ )
1240
+
1241
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1242
+ self._use_sdpa = config._attn_implementation == "sdpa"
1243
+
1244
+ self.gradient_checkpointing = False
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ def forward(
1249
+ self,
1250
+ input_ids: Optional[torch.Tensor] = None,
1251
+ attention_mask: Optional[torch.Tensor] = None,
1252
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1253
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1254
+ head_mask: Optional[torch.Tensor] = None,
1255
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1256
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1257
+ inputs_embeds: Optional[torch.Tensor] = None,
1258
+ use_cache: Optional[bool] = None,
1259
+ output_attentions: Optional[bool] = None,
1260
+ output_hidden_states: Optional[bool] = None,
1261
+ return_dict: Optional[bool] = None,
1262
+ ):
1263
+ r"""
1264
+ Args:
1265
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1266
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1267
+ provide it.
1268
+
1269
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1270
+ [`PreTrainedTokenizer.__call__`] for details.
1271
+
1272
+ [What are input IDs?](../glossary#input-ids)
1273
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1274
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1275
+
1276
+ - 1 for tokens that are **not masked**,
1277
+ - 0 for tokens that are **masked**.
1278
+
1279
+ [What are attention masks?](../glossary#attention-mask)
1280
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1281
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1282
+ of the decoder.
1283
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1284
+ Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
1285
+ selected in `[0, 1]`:
1286
+
1287
+ - 1 for tokens that are **not masked**,
1288
+ - 0 for tokens that are **masked**.
1289
+
1290
+ [What are attention masks?](../glossary#attention-mask)
1291
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1292
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1293
+
1294
+ - 1 indicates the head is **not masked**,
1295
+ - 0 indicates the head is **masked**.
1296
+
1297
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1298
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1299
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1300
+
1301
+ - 1 indicates the head is **not masked**,
1302
+ - 0 indicates the head is **masked**.
1303
+
1304
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1305
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1306
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1307
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1308
+
1309
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1310
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1311
+
1312
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1313
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1314
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1315
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1316
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1317
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1318
+ than the model's internal embedding lookup matrix.
1319
+ output_attentions (`bool`, *optional*):
1320
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1321
+ returned tensors for more detail.
1322
+ output_hidden_states (`bool`, *optional*):
1323
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1324
+ for more detail.
1325
+ return_dict (`bool`, *optional*):
1326
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1327
+ """
1328
+ output_attentions = (
1329
+ output_attentions
1330
+ if output_attentions is not None
1331
+ else self.config.output_attentions
1332
+ )
1333
+ output_hidden_states = (
1334
+ output_hidden_states
1335
+ if output_hidden_states is not None
1336
+ else self.config.output_hidden_states
1337
+ )
1338
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # retrieve input_ids and inputs_embeds
1344
+ if input_ids is not None and inputs_embeds is not None:
1345
+ raise ValueError(
1346
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1347
+ )
1348
+ elif input_ids is not None:
1349
+ input_shape = input_ids.size()
1350
+ input_ids = input_ids.view(-1, input_shape[-1])
1351
+ elif inputs_embeds is not None:
1352
+ input_shape = inputs_embeds.size()[:-1]
1353
+ else:
1354
+ raise ValueError(
1355
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1356
+ )
1357
+
1358
+ # past_key_values_length
1359
+ past_key_values_length = (
1360
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1361
+ )
1362
+
1363
+ if inputs_embeds is None:
1364
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1365
+
1366
+
1367
+ if self._use_flash_attention_2:
1368
+ # 2d mask is passed through the layers
1369
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1370
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1371
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1372
+ # the manual implementation that requires a 4D causal mask in all cases.
1373
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1374
+ attention_mask,
1375
+ input_shape,
1376
+ inputs_embeds,
1377
+ past_key_values_length,
1378
+ )
1379
+ else:
1380
+ # 4d mask is passed through the layers
1381
+ attention_mask = _prepare_4d_causal_attention_mask(
1382
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1383
+ )
1384
+
1385
+ # expand encoder attention mask
1386
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1387
+ if self._use_flash_attention_2:
1388
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1389
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
1390
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1391
+ # the manual implementation that requires a 4D causal mask in all cases.
1392
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1393
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1394
+ encoder_attention_mask,
1395
+ inputs_embeds.dtype,
1396
+ tgt_len=input_shape[-1],
1397
+ )
1398
+ else:
1399
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1400
+ encoder_attention_mask = _prepare_4d_attention_mask(
1401
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1402
+ )
1403
+
1404
+ # embed positions
1405
+ positions = self.embed_positions(
1406
+ input_ids, inputs_embeds, past_key_values_length
1407
+ )
1408
+ positions = positions.to(inputs_embeds.device)
1409
+
1410
+ hidden_states = inputs_embeds + positions
1411
+ if self.layernorm_embedding is not None:
1412
+ hidden_states = self.layernorm_embedding(hidden_states)
1413
+
1414
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1415
+
1416
+ if self.gradient_checkpointing and self.training:
1417
+ if use_cache:
1418
+ logger.warning_once(
1419
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
1420
+ " `use_cache=False`..."
1421
+ )
1422
+ use_cache = False
1423
+
1424
+ # decoder layers
1425
+ all_hidden_states = () if output_hidden_states else None
1426
+ all_self_attns = () if output_attentions else None
1427
+ all_cross_attentions = () if output_attentions else None
1428
+ next_decoder_cache = () if use_cache else None
1429
+
1430
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1431
+ for attn_mask, mask_name in zip(
1432
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1433
+ ):
1434
+ if attn_mask is not None:
1435
+ if attn_mask.size()[0] != len(self.layers):
1436
+ raise ValueError(
1437
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1438
+ f" {head_mask.size()[0]}."
1439
+ )
1440
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1441
+
1442
+ for idx, decoder_layer in enumerate(self.layers):
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1447
+ dropout_probability = torch.rand([])
1448
+
1449
+ skip_the_layer = (
1450
+ True
1451
+ if self.training and (dropout_probability < self.layerdrop)
1452
+ else False
1453
+ )
1454
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1455
+ # under deepspeed zero3 all gpus must run in sync
1456
+
1457
+ past_key_value = (
1458
+ past_key_values[idx] if past_key_values is not None else None
1459
+ )
1460
+
1461
+ if self.gradient_checkpointing and self.training:
1462
+
1463
+ def create_custom_forward(module):
1464
+ def custom_forward(*inputs):
1465
+ # None for past_key_value
1466
+ return module(*inputs, output_attentions, use_cache)
1467
+
1468
+ return custom_forward
1469
+
1470
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1471
+ create_custom_forward(decoder_layer),
1472
+ hidden_states,
1473
+ attention_mask,
1474
+ encoder_hidden_states,
1475
+ encoder_attention_mask,
1476
+ head_mask[idx] if head_mask is not None else None,
1477
+ cross_attn_head_mask[idx]
1478
+ if cross_attn_head_mask is not None
1479
+ else None,
1480
+ None,
1481
+ )
1482
+ else:
1483
+ layer_outputs = decoder_layer(
1484
+ hidden_states,
1485
+ attention_mask=attention_mask,
1486
+ encoder_hidden_states=encoder_hidden_states,
1487
+ encoder_attention_mask=encoder_attention_mask,
1488
+ layer_head_mask=(
1489
+ head_mask[idx] if head_mask is not None else None
1490
+ ),
1491
+ cross_attn_layer_head_mask=(
1492
+ cross_attn_head_mask[idx]
1493
+ if cross_attn_head_mask is not None
1494
+ else None
1495
+ ),
1496
+ past_key_value=past_key_value,
1497
+ output_attentions=output_attentions,
1498
+ use_cache=use_cache,
1499
+ )
1500
+
1501
+ hidden_states = layer_outputs[0]
1502
+
1503
+ if skip_the_layer:
1504
+ continue
1505
+
1506
+ if use_cache:
1507
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1508
+
1509
+ if output_attentions:
1510
+ all_self_attns += (layer_outputs[1],)
1511
+ all_cross_attentions += (layer_outputs[2],)
1512
+
1513
+ if self.layer_norm is not None:
1514
+ hidden_states = self.layer_norm(hidden_states)
1515
+
1516
+ # add hidden states from the last decoder layer
1517
+ if output_hidden_states:
1518
+ all_hidden_states += (hidden_states,)
1519
+
1520
+ next_cache = next_decoder_cache if use_cache else None
1521
+ if not return_dict:
1522
+ return tuple(
1523
+ v
1524
+ for v in [
1525
+ hidden_states,
1526
+ next_cache,
1527
+ all_hidden_states,
1528
+ all_self_attns,
1529
+ all_cross_attentions,
1530
+ ]
1531
+ if v is not None
1532
+ )
1533
+ return BaseModelOutputWithPastAndCrossAttentions(
1534
+ last_hidden_state=hidden_states,
1535
+ past_key_values=next_cache,
1536
+ hidden_states=all_hidden_states,
1537
+ attentions=all_self_attns,
1538
+ cross_attentions=all_cross_attentions,
1539
+ )
1540
+
1541
+
1542
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
1543
+ class IndicTransModel(IndicTransPreTrainedModel):
1544
+ _tied_weights_keys = None
1545
+
1546
+ def __init__(self, config: IndicTransConfig):
1547
+ super().__init__(config)
1548
+
1549
+ self.encoder = IndicTransEncoder(config)
1550
+ self.decoder = IndicTransDecoder(config)
1551
+
1552
+ # Initialize weights and apply final processing
1553
+ self.post_init()
1554
+
1555
+ def get_encoder(self):
1556
+ return self.encoder
1557
+
1558
+ def get_decoder(self):
1559
+ return self.decoder
1560
+
1561
+ def forward(
1562
+ self,
1563
+ input_ids: Optional[torch.LongTensor] = None,
1564
+ attention_mask: Optional[torch.Tensor] = None,
1565
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1566
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1567
+ head_mask: Optional[torch.Tensor] = None,
1568
+ decoder_head_mask: Optional[torch.Tensor] = None,
1569
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1570
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1571
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1572
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1573
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1574
+ use_cache: Optional[bool] = None,
1575
+ output_attentions: Optional[bool] = None,
1576
+ output_hidden_states: Optional[bool] = None,
1577
+ return_dict: Optional[bool] = None,
1578
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
1579
+ output_attentions = (
1580
+ output_attentions
1581
+ if output_attentions is not None
1582
+ else self.config.output_attentions
1583
+ )
1584
+ output_hidden_states = (
1585
+ output_hidden_states
1586
+ if output_hidden_states is not None
1587
+ else self.config.output_hidden_states
1588
+ )
1589
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1590
+ return_dict = (
1591
+ return_dict if return_dict is not None else self.config.use_return_dict
1592
+ )
1593
+
1594
+ if encoder_outputs is None:
1595
+ encoder_outputs = self.encoder(
1596
+ input_ids=input_ids,
1597
+ attention_mask=attention_mask,
1598
+ head_mask=head_mask,
1599
+ inputs_embeds=inputs_embeds,
1600
+ output_attentions=output_attentions,
1601
+ output_hidden_states=output_hidden_states,
1602
+ return_dict=return_dict,
1603
+ )
1604
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1605
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1606
+ encoder_outputs = BaseModelOutput(
1607
+ last_hidden_state=encoder_outputs[0],
1608
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1609
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1610
+ )
1611
+
1612
+ # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
1613
+ decoder_outputs = self.decoder(
1614
+ input_ids=decoder_input_ids,
1615
+ attention_mask=decoder_attention_mask,
1616
+ encoder_hidden_states=encoder_outputs[0],
1617
+ encoder_attention_mask=attention_mask,
1618
+ head_mask=decoder_head_mask,
1619
+ cross_attn_head_mask=cross_attn_head_mask,
1620
+ past_key_values=past_key_values,
1621
+ inputs_embeds=decoder_inputs_embeds,
1622
+ use_cache=use_cache,
1623
+ output_attentions=output_attentions,
1624
+ output_hidden_states=output_hidden_states,
1625
+ return_dict=return_dict,
1626
+ )
1627
+
1628
+ if not return_dict:
1629
+ return decoder_outputs + encoder_outputs
1630
+
1631
+ return Seq2SeqModelOutput(
1632
+ last_hidden_state=decoder_outputs.last_hidden_state,
1633
+ past_key_values=decoder_outputs.past_key_values,
1634
+ decoder_hidden_states=decoder_outputs.hidden_states,
1635
+ decoder_attentions=decoder_outputs.attentions,
1636
+ cross_attentions=decoder_outputs.cross_attentions,
1637
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1638
+ encoder_hidden_states=encoder_outputs.hidden_states,
1639
+ encoder_attentions=encoder_outputs.attentions,
1640
+ )
1641
+
1642
+
1643
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1644
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1645
+ base_model_prefix = "model"
1646
+ _tied_weights_keys = None
1647
+ _label_smoothing = 0.0
1648
+
1649
+ def __init__(self, config: IndicTransConfig):
1650
+ super().__init__(config)
1651
+ self.model = IndicTransModel(config)
1652
+ self.lm_head = nn.Linear(
1653
+ config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1654
+ )
1655
+
1656
+ if config.share_decoder_input_output_embed:
1657
+ self.lm_head.weight = self.model.decoder.embed_tokens.weight
1658
+
1659
+ self.post_init()
1660
+
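+ # Default weight tying is disabled: encoder and decoder use separate vocabularies, and the LM head is tied manually in __init__ when share_decoder_input_output_embed is set.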
1661
+ def tie_weights(self):
1662
+ pass
1663
+
1664
+ def get_encoder(self):
1665
+ return self.model.get_encoder()
1666
+
1667
+ def get_decoder(self):
1668
+ return self.model.get_decoder()
1669
+
1670
+ def get_output_embeddings(self):
1671
+ return self.lm_head
1672
+
1673
+ def set_output_embeddings(self, new_embeddings):
1674
+ self.lm_head = new_embeddings
1675
+
1676
+ def set_label_smoothing(self, label_smoothing):
1677
+ self._label_smoothing = label_smoothing
1678
+
1679
+ def forward(
1680
+ self,
1681
+ input_ids: Optional[torch.LongTensor] = None,
1682
+ attention_mask: Optional[torch.Tensor] = None,
1683
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1684
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1685
+ head_mask: Optional[torch.Tensor] = None,
1686
+ decoder_head_mask: Optional[torch.Tensor] = None,
1687
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1688
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1689
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1690
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1691
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1692
+ labels: Optional[torch.LongTensor] = None,
1693
+ use_cache: Optional[bool] = None,
1694
+ output_attentions: Optional[bool] = None,
1695
+ output_hidden_states: Optional[bool] = None,
1696
+ return_dict: Optional[bool] = None,
1697
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
1698
+ r"""
1699
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1700
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1701
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1702
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1703
+
1704
+ Returns:
1705
+ """
1706
+ return_dict = (
1707
+ return_dict if return_dict is not None else self.config.use_return_dict
1708
+ )
1709
+
1710
+ if labels is not None:
1711
+ if decoder_input_ids is None:
1712
+ decoder_input_ids = shift_tokens_right(
1713
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1714
+ )
1715
+
1716
+ outputs = self.model(
1717
+ input_ids,
1718
+ attention_mask=attention_mask,
1719
+ decoder_input_ids=decoder_input_ids,
1720
+ encoder_outputs=encoder_outputs,
1721
+ decoder_attention_mask=decoder_attention_mask,
1722
+ head_mask=head_mask,
1723
+ decoder_head_mask=decoder_head_mask,
1724
+ cross_attn_head_mask=cross_attn_head_mask,
1725
+ past_key_values=past_key_values,
1726
+ inputs_embeds=inputs_embeds,
1727
+ decoder_inputs_embeds=decoder_inputs_embeds,
1728
+ use_cache=use_cache,
1729
+ output_attentions=output_attentions,
1730
+ output_hidden_states=output_hidden_states,
1731
+ return_dict=return_dict,
1732
+ )
1733
+ lm_logits = self.lm_head(outputs[0])
1734
+
1735
+ masked_lm_loss = None
1736
+ if labels is not None:
1737
+ # move labels to the correct device to enable PP
1738
+ labels = labels.to(lm_logits.device)
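+ # Token-level cross-entropy over the flattened logits; positions labelled -100 are ignored and label smoothing is configurable via set_label_smoothing().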
1739
+ masked_lm_loss = F.cross_entropy(
1740
+ input=lm_logits.view(-1, self.config.decoder_vocab_size),
1741
+ target=labels.view(-1),
1742
+ ignore_index=-100,
1743
+ label_smoothing=self._label_smoothing,
1744
+ )
1745
+
1746
+ if not return_dict:
1747
+ output = (lm_logits,) + outputs[1:]
1748
+ return (
1749
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1750
+ )
1751
+
1752
+ return Seq2SeqLMOutput(
1753
+ loss=masked_lm_loss,
1754
+ logits=lm_logits,
1755
+ past_key_values=outputs.past_key_values,
1756
+ decoder_hidden_states=outputs.decoder_hidden_states,
1757
+ decoder_attentions=outputs.decoder_attentions,
1758
+ cross_attentions=outputs.cross_attentions,
1759
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1760
+ encoder_hidden_states=outputs.encoder_hidden_states,
1761
+ encoder_attentions=outputs.encoder_attentions,
1762
+ )
1763
+
1764
+ def prepare_inputs_for_generation(
1765
+ self,
1766
+ decoder_input_ids,
1767
+ past_key_values=None,
1768
+ attention_mask=None,
1769
+ head_mask=None,
1770
+ decoder_head_mask=None,
1771
+ cross_attn_head_mask=None,
1772
+ use_cache=None,
1773
+ encoder_outputs=None,
1774
+ **kwargs,
1775
+ ):
1776
+ # cut decoder_input_ids if past is used
1777
+ if past_key_values is not None:
1778
+ decoder_input_ids = decoder_input_ids[:, -1:]
1779
+
1780
+ return {
1781
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1782
+ "encoder_outputs": encoder_outputs,
1783
+ "past_key_values": past_key_values,
1784
+ "decoder_input_ids": decoder_input_ids,
1785
+ "attention_mask": attention_mask,
1786
+ "head_mask": head_mask,
1787
+ "decoder_head_mask": decoder_head_mask,
1788
+ "cross_attn_head_mask": cross_attn_head_mask,
1789
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1790
+ }
1791
+
1792
+ @staticmethod
1793
+ def _reorder_cache(past_key_values, beam_idx):
1794
+ reordered_past = ()
1795
+ for layer_past in past_key_values:
1796
+ reordered_past += (
1797
+ tuple(
1798
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1799
+ ),
1800
+ )
1801
+ return reordered_past
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_indictrans.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import json
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ from transformers.utils import logging
7
+ from sentencepiece import SentencePieceProcessor
8
+ from transformers.tokenization_utils import PreTrainedTokenizer
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ SPIECE_UNDERLINE = "▁"
14
+
15
+ SPECIAL_TAGS = {
16
+ "_bt_",
17
+ "_ft_",
18
+ "asm_Beng",
19
+ "awa_Deva",
20
+ "ben_Beng",
21
+ "bho_Deva",
22
+ "brx_Deva",
23
+ "doi_Deva",
24
+ "eng_Latn",
25
+ "gom_Deva",
26
+ "gon_Deva",
27
+ "guj_Gujr",
28
+ "hin_Deva",
29
+ "hne_Deva",
30
+ "kan_Knda",
31
+ "kas_Arab",
32
+ "kas_Deva",
33
+ "kha_Latn",
34
+ "lus_Latn",
35
+ "mag_Deva",
36
+ "mai_Deva",
37
+ "mal_Mlym",
38
+ "mar_Deva",
39
+ "mni_Beng",
40
+ "mni_Mtei",
41
+ "npi_Deva",
42
+ "ory_Orya",
43
+ "pan_Guru",
44
+ "san_Deva",
45
+ "sat_Olck",
46
+ "snd_Arab",
47
+ "snd_Deva",
48
+ "tam_Taml",
49
+ "tel_Telu",
50
+ "urd_Arab",
51
+ "unr_Deva",
52
+ }
53
+
54
+ VOCAB_FILES_NAMES = {
55
+ "src_vocab_fp": "dict.SRC.json",
56
+ "tgt_vocab_fp": "dict.TGT.json",
57
+ "src_spm_fp": "model.SRC",
58
+ "tgt_spm_fp": "model.TGT",
59
+ }
60
+
61
+
62
+ class IndicTransTokenizer(PreTrainedTokenizer):
63
+ _added_tokens_encoder = {}
64
+ _added_tokens_decoder = {}
65
+
66
+ vocab_files_names = VOCAB_FILES_NAMES
67
+ model_input_names = ["input_ids", "attention_mask"]
68
+
69
+ def __init__(
70
+ self,
71
+ src_vocab_fp=None,
72
+ tgt_vocab_fp=None,
73
+ src_spm_fp=None,
74
+ tgt_spm_fp=None,
75
+ unk_token="<unk>",
76
+ bos_token="<s>",
77
+ eos_token="</s>",
78
+ pad_token="<pad>",
79
+ do_lower_case=False,
80
+ **kwargs,
81
+ ):
82
+
83
+ self.src = True
84
+
85
+ self.src_vocab_fp = src_vocab_fp
86
+ self.tgt_vocab_fp = tgt_vocab_fp
87
+ self.src_spm_fp = src_spm_fp
88
+ self.tgt_spm_fp = tgt_spm_fp
89
+
90
+ self.unk_token = unk_token.content
91
+ self.pad_token = pad_token.content
92
+ self.eos_token = eos_token.content
93
+ self.bos_token = bos_token.content
94
+
95
+ self.encoder = self._load_json(self.src_vocab_fp)
96
+ if self.unk_token not in self.encoder:
97
+ raise KeyError("<unk> token must be in vocab")
98
+ assert self.pad_token in self.encoder
99
+ self.encoder_rev = {v: k for k, v in self.encoder.items()}
100
+
101
+ self.decoder = self._load_json(self.tgt_vocab_fp)
102
+ if self.unk_token not in self.encoder:
103
+ raise KeyError("<unk> token must be in vocab")
104
+ assert self.pad_token in self.encoder
105
+ self.decoder_rev = {v: k for k, v in self.decoder.items()}
106
+
107
+ # load SentencePiece model for pre-processing
108
+ self.src_spm = self._load_spm(self.src_spm_fp)
109
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
110
+
111
+ self.current_spm = self.src_spm
112
+ self.current_encoder = self.encoder
113
+ self.current_encoder_rev = self.encoder_rev
114
+
115
+ self.unk_token_id = self.encoder[self.unk_token]
116
+ self.pad_token_id = self.encoder[self.pad_token]
117
+ self.eos_token_id = self.encoder[self.eos_token]
118
+ self.bos_token_id = self.encoder[self.bos_token]
119
+
120
+ super().__init__(
121
+ src_vocab_file=self.src_vocab_fp,
122
+ tgt_vocab_file=self.src_vocab_fp,
123
+ do_lower_case=do_lower_case,
124
+ unk_token=unk_token,
125
+ bos_token=bos_token,
126
+ eos_token=eos_token,
127
+ pad_token=pad_token,
128
+ **kwargs,
129
+ )
130
+
131
+ def add_new_special_tags(self, new_tags: List[str]):
132
+ SPECIAL_TAGS.update(new_tags)
133
+
134
+ def _switch_to_input_mode(self):
135
+ self.src = True
136
+ self.padding_side = "left"
137
+ self.current_spm = self.src_spm
138
+ self.current_encoder = self.encoder
139
+ self.current_encoder_rev = self.encoder_rev
140
+
141
+ def _switch_to_target_mode(self):
142
+ self.src = False
143
+ self.padding_side = "right"
144
+ self.current_spm = self.tgt_spm
145
+ self.current_encoder = self.decoder
146
+ self.current_encoder_rev = self.decoder_rev
147
+
148
+ def _load_spm(self, path: str) -> SentencePieceProcessor:
149
+ return SentencePieceProcessor(model_file=path)
150
+
151
+ def _save_json(self, data, path: str) -> None:
152
+ with open(path, "w", encoding="utf-8") as f:
153
+ json.dump(data, f, indent=2)
154
+
155
+ def _load_json(self, path: str) -> Union[Dict, List]:
156
+ with open(path, "r", encoding="utf-8") as f:
157
+ return json.load(f)
158
+
159
+ def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
160
+ tags = [token for token in tokens if token in SPECIAL_TAGS]
161
+ tokens = [token for token in tokens if token not in SPECIAL_TAGS]
162
+ return tags, tokens
163
+
164
+ def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
165
+ pads = [token for token in tokens if token == self.pad_token]
166
+ tokens = [token for token in tokens if token != self.pad_token]
167
+ return pads, tokens
168
+
169
+ @property
170
+ def src_vocab_size(self) -> int:
171
+ return len(self.encoder)
172
+
173
+ @property
174
+ def tgt_vocab_size(self) -> int:
175
+ return len(self.decoder)
176
+
177
+ def get_src_vocab(self) -> Dict[str, int]:
178
+ return dict(self.encoder, **self.added_tokens_encoder)
179
+
180
+ def get_tgt_vocab(self) -> Dict[str, int]:
181
+ return dict(self.decoder, **self.added_tokens_decoder)
182
+
183
+ # hack override
184
+ def get_vocab(self) -> Dict[str, int]:
185
+ return self.get_src_vocab()
186
+
187
+ # hack override
188
+ @property
189
+ def vocab_size(self) -> int:
190
+ return self.src_vocab_size
191
+
192
+ def _convert_token_to_id(self, token: str) -> int:
193
+ """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
194
+ return self.current_encoder.get(token, self.current_encoder[self.unk_token])
195
+
196
+ def _convert_id_to_token(self, index: int) -> str:
197
+ """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
198
+ return self.current_encoder_rev.get(index, self.unk_token)
199
+
200
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
201
+ """Uses sentencepiece model for detokenization"""
202
+ pads, tokens = self._split_pads(tokens)
203
+
204
+ if self.src:
205
+
206
+ tags, non_tags = self._split_tags(tokens)
207
+
208
+ return (
209
+ " ".join(pads)
210
+ + " "
211
+ + " ".join(tags)
212
+ + " "
213
+ + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
214
+ )
215
+
216
+ return (
217
+ "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
218
+ + " "
219
+ + " ".join(pads)
220
+ )
221
+
222
+ def _tokenize(self, text) -> List[str]:
223
+ if self.src:
224
+ tokens = text.split(" ")
225
+ tags, non_tags = self._split_tags(tokens)
226
+ text = " ".join(non_tags)
227
+ tokens = self.current_spm.EncodeAsPieces(text)
228
+ return tags + tokens
229
+ else:
230
+ return self.current_spm.EncodeAsPieces(text)
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ if token_ids_1 is None:
236
+ return token_ids_0 + [self.eos_token_id]
237
+ # We don't expect to process pairs, but leave the pair logic for API consistency
238
+ return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
239
+
240
+ def save_vocabulary(
241
+ self, save_directory: str, filename_prefix: Optional[str] = None
242
+ ) -> Tuple[str]:
243
+ if not os.path.isdir(save_directory):
244
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
245
+ return
246
+
247
+ src_spm_fp = os.path.join(save_directory, "model.SRC")
248
+ tgt_spm_fp = os.path.join(save_directory, "model.TGT")
249
+ src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
250
+ tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
251
+
252
+ self._save_json(self.encoder, src_vocab_fp)
253
+ self._save_json(self.decoder, tgt_vocab_fp)
254
+
255
+ with open(src_spm_fp, "wb") as f:
256
+ f.write(self.src_spm.serialized_model_proto())
257
+
258
+ with open(tgt_spm_fp, "wb") as f:
259
+ f.write(self.tgt_spm.serialized_model_proto())
260
+
261
+ return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "auto_map": {
37
+ "AutoTokenizer": [
38
+ "tokenization_indictrans.IndicTransTokenizer",
39
+ null
40
+ ]
41
+ },
42
+ "bos_token": "<s>",
43
+ "clean_up_tokenization_spaces": true,
44
+ "do_lower_case": false,
45
+ "eos_token": "</s>",
46
+ "model_max_length": 256,
47
+ "pad_token": "<pad>",
48
+ "src_vocab_file": "EnIndicNETokeniser/dict.SRC.json",
49
+ "tgt_vocab_file": "EnIndicNETokeniser/dict.SRC.json",
50
+ "tokenizer_class": "IndicTransTokenizer",
51
+ "unk_token": "<unk>"
52
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0896fc3439f499c79631080eecce2776e8c9e1cbf20378003c3135c92b893fcb
3
+ size 5368