victan committed
Commit 719e3d4 · Parent: d6ab6ec

Upload seamless_communication/models/monotonic_decoder/loader.py with huggingface_hub

seamless_communication/models/monotonic_decoder/loader.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import Any, Mapping
+
+ import torch
+ from fairseq2.assets import asset_store, download_manager
+ from fairseq2.models.utils import ConfigLoader, ModelLoader
+ from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
+
+ from seamless_communication.models.monotonic_decoder.builder import (
+     MonotonicDecoderConfig,
+     create_monotonic_decoder_model,
+     monotonic_decoder_archs,
+ )
+ from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+
+
+ def convert_monotonic_checkpoint(
+     checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
+ ) -> Mapping[str, Any]:
+     state_dict = checkpoint["model"]
+
+     # Check if we have a fairseq2 checkpoint.
+     if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
+         return checkpoint
+
+     key_map = {
+         # fmt: off
+         r"^decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
+         r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
+         r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
+         r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias": r"text_decoder.layers.\1.p_choose_layer.energy_bias",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
+         r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+         r"^decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
+         r"^decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
+         r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
+         r"^decoder\.layer_norm\.": r"text_decoder.layer_norm.",
+         r"^decoder\.output_projection\.": r"final_proj.",
+         # fmt: on
+     }
+
+     # Convert to fairseq2.
+     checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
+
+     state_dict = checkpoint["model"]
+
+     embeds = state_dict["final_proj.weight"]
+
+     # fairseq had a bug that accidentally introduced a dummy token in the
+     # embedding table of NLLB-100. We just discard it.
+     if embeds.size(0) == 256103:  # means NLLB-100
+         embeds = embeds[:-1]
+
+         state_dict["final_proj.weight"] = embeds
+
+     # fairseq checkpoints have duplicate embedding weights. Ensure that we
+     # use a single embedding table in fairseq2.
+     state_dict["text_decoder_frontend.embed.weight"] = embeds
+
+     # The embedding positions of the control symbols in fairseq's dict do
+     # not match the SentencePiece model of the tokenizer.
+     with torch.inference_mode():
+         # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+         embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
+     return checkpoint
+
+
+ load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
+     asset_store, monotonic_decoder_archs
+ )
+
+
+ load_monotonic_decoder_model = ModelLoader[
+     MonotonicDecoderModel, MonotonicDecoderConfig
+ ](
+     asset_store,
+     download_manager,
+     load_monotonic_decoder_config,
+     create_monotonic_decoder_model,
+     convert_monotonic_checkpoint,
+     restrict_checkpoints=False,
+ )
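
A toy illustration of the control-symbol permutation performed near the end of convert_monotonic_checkpoint (not part of the commit; placeholder values only): with rows stored in fairseq order (BOS, PAD, EOS, UNK), the fancy-index assignment rearranges them into (PAD, UNK, BOS, EOS), the order the SentencePiece tokenizer expects.

import torch

# Rows stand for BOS, PAD, EOS, UNK; the values are arbitrary placeholders.
embeds = torch.tensor([[0.0], [1.0], [2.0], [3.0]])

# Same assignment as in the loader: position 0 gets old row 1 (PAD),
# position 1 gets old row 3 (UNK), position 2 gets old row 0 (BOS),
# position 3 gets old row 2 (EOS).
embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]

print(embeds.squeeze().tolist())  # [1.0, 3.0, 0.0, 2.0] -> PAD, UNK, BOS, EOS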
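
A minimal usage sketch for the two loaders defined at the bottom of the file (not part of the commit): both are callables that resolve an asset card, download the checkpoint, and build the model. The card name "seamless_streaming_monotonic_decoder" and the device/dtype keyword arguments are assumptions based on the fairseq2 ModelLoader interface, not something this file defines.

import torch

from seamless_communication.models.monotonic_decoder.loader import (
    load_monotonic_decoder_config,
    load_monotonic_decoder_model,
)

# Hypothetical asset card name; substitute whichever card is registered
# in your asset store.
CARD = "seamless_streaming_monotonic_decoder"

config = load_monotonic_decoder_config(CARD)
model = load_monotonic_decoder_model(
    CARD, device=torch.device("cpu"), dtype=torch.float32
)
model.eval()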