Upload seamless_communication/models/monotonic_decoder/loader.py with huggingface_hub
seamless_communication/models/monotonic_decoder/loader.py
ADDED
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from typing import Any, Mapping
+
+import torch
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
+
+from seamless_communication.models.monotonic_decoder.builder import (
+    MonotonicDecoderConfig,
+    create_monotonic_decoder_model,
+    monotonic_decoder_archs,
+)
+from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+
+
+def convert_monotonic_checkpoint(
+    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
+) -> Mapping[str, Any]:
+    state_dict = checkpoint["model"]
+
+    # Check if we have a fairseq2 checkpoint.
+    if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
+        return checkpoint
+
+    key_map = {
+        # fmt: off
+        r"^decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias": r"text_decoder.layers.\1.p_choose_layer.energy_bias",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.layer_norm\.": r"text_decoder.layer_norm.",
+        r"^decoder\.output_projection\.": r"final_proj.",
+        # fmt: on
+    }
+
+    # Convert to fairseq2.
+    checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
+
+    state_dict = checkpoint["model"]
+
+    embeds = state_dict["final_proj.weight"]
+
+    # fairseq had a bug that accidentally introduced a dummy token in the
+    # embedding table of NLLB-100. We just discard it.
+    if embeds.size(0) == 256103:  # means NLLB-100
+        embeds = embeds[:-1]
+
+        state_dict["final_proj.weight"] = embeds
+
+    # fairseq checkpoints have duplicate embedding weights. Ensure that we
+    # use a single embedding table in fairseq2.
+    state_dict["text_decoder_frontend.embed.weight"] = embeds
+
+    # The embedding positions of the control symbols in fairseq's dict do
+    # not match the SentencePiece model of the tokenizer.
+    with torch.inference_mode():
+        # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+        embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
+    return checkpoint
+
+
+load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
+    asset_store, monotonic_decoder_archs
+)
+
+
+load_monotonic_decoder_model = ModelLoader[
+    MonotonicDecoderModel, MonotonicDecoderConfig
+](
+    asset_store,
+    download_manager,
+    load_monotonic_decoder_config,
+    create_monotonic_decoder_model,
+    convert_monotonic_checkpoint,
+    restrict_checkpoints=False,
+)
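
For context, a minimal usage sketch of the two loaders defined above. It assumes an asset card named "seamless_streaming_monotonic_decoder" is registered in the fairseq2 asset store bundled with seamless_communication (substitute whichever card name your installation provides) and that your fairseq2 version's ModelLoader accepts device and dtype keyword arguments; this is an illustrative sketch, not part of the uploaded file.

import torch

from seamless_communication.models.monotonic_decoder.loader import (
    load_monotonic_decoder_config,
    load_monotonic_decoder_model,
)

# Assumed asset card name; replace with the card available in your setup.
MODEL_NAME = "seamless_streaming_monotonic_decoder"

# Resolve the architecture config from the asset card.
config = load_monotonic_decoder_config(MODEL_NAME)
print(config)

# Download the checkpoint if needed, convert it on the fly via
# convert_monotonic_checkpoint, and build a MonotonicDecoderModel.
model = load_monotonic_decoder_model(
    MODEL_NAME, device=torch.device("cpu"), dtype=torch.float32
)
model.eval()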