File size: 4,253 Bytes
719e3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from typing import Any, Mapping

import torch
from fairseq2.assets import asset_store, download_manager
from fairseq2.models.utils import ConfigLoader, ModelLoader
from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint

from seamless_communication.models.monotonic_decoder.builder import (
    MonotonicDecoderConfig,
    create_monotonic_decoder_model,
    monotonic_decoder_archs,
)
from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel


def convert_monotonic_checkpoint(
    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
) -> Mapping[str, Any]:
    """Convert a fairseq monotonic decoder checkpoint to fairseq2 format.

    :param checkpoint:
        The loaded checkpoint. Its ``"model"`` entry holds the state dict.
    :param config:
        The model configuration. Not used by this conversion.

    :returns:
        ``checkpoint`` itself if its state dict already uses fairseq2
        parameter names; otherwise the checkpoint returned by
        ``convert_fairseq_checkpoint`` with parameter names remapped and the
        shared embedding table fixed up (NLLB-100 dummy row dropped, control
        symbol rows reordered).
    """
    # Parameter names already follow fairseq2 conventions; nothing to do.
    if "text_decoder.layers.0.self_attn.k_proj.weight" in checkpoint["model"]:
        return checkpoint

    # fairseq -> fairseq2 parameter-name rewrites. Ordering matters if
    # convert_fairseq_checkpoint applies the first matching pattern (assumed,
    # confirm in fairseq2): the specific out_proj / energy_bias / energy-layer
    # patterns must come before the generic self_attn / encoder_attn prefixes.
    key_map = {
        # fmt: off
        r"^decoder\.embed_tokens\.":                                            r"text_decoder_frontend.embed.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":                   r"text_decoder.layers.\1.self_attn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.":                             r"text_decoder.layers.\1.self_attn.",
        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":                  r"text_decoder.layers.\1.self_attn_layer_norm.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":                r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias":               r"text_decoder.layers.\1.p_choose_layer.energy_bias",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.":     r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":                          r"text_decoder.layers.\1.encoder_decoder_attn.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.":               r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
        r"^decoder\.layers\.([0-9]+)\.fc1\.":                                   r"text_decoder.layers.\1.ffn.inner_proj.",
        r"^decoder\.layers\.([0-9]+)\.fc2\.":                                   r"text_decoder.layers.\1.ffn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":                      r"text_decoder.layers.\1.ffn_layer_norm.",
        r"^decoder\.layer_norm\.":                                              r"text_decoder.layer_norm.",
        r"^decoder\.output_projection\.":                                       r"final_proj.",
        # fmt: on
    }

    fs2_checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)
    fs2_params = fs2_checkpoint["model"]

    shared_embed = fs2_params["final_proj.weight"]

    # fairseq had a bug that accidentally introduced a dummy token in the
    # embedding table of NLLB-100. A first dimension of 256103 identifies
    # that table; drop its trailing dummy row.
    if shared_embed.size(0) == 256103:
        shared_embed = shared_embed[:-1]
        fs2_params["final_proj.weight"] = shared_embed

    # fairseq checkpoints store the embedding weights twice. Point the decoder
    # input embedding at the very same tensor as the output projection so
    # fairseq2 uses a single table. Because both entries now alias one tensor,
    # the in-place reorder below applies to both.
    fs2_params["text_decoder_frontend.embed.weight"] = shared_embed

    # fairseq's dictionary puts the control symbols at different rows than
    # the tokenizer's SentencePiece model does:
    #   fairseq:        0=BOS, 1=PAD, 2=EOS, 3=UNK
    #   SentencePiece:  0=PAD, 1=UNK, 2=BOS, 3=EOS
    # Row target_rows[i] receives old row source_rows[i]. Indexing with a list
    # produces a copy on the right-hand side, so the in-place write never
    # reads a row it has already overwritten.
    target_rows = [0, 1, 2, 3]
    source_rows = [1, 3, 0, 2]
    with torch.inference_mode():
        shared_embed[target_rows] = shared_embed[source_rows]

    return fs2_checkpoint


# Loader for MonotonicDecoderConfig instances, built from the shared asset
# store and the architectures registered in ``monotonic_decoder_archs``.
load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
    asset_store,
    monotonic_decoder_archs,
)


# Loader for MonotonicDecoderModel instances. It wires together:
#   - the asset store and download manager,
#   - ``load_monotonic_decoder_config`` (above) for the model configuration,
#   - ``create_monotonic_decoder_model`` as the model factory,
#   - ``convert_monotonic_checkpoint`` to turn fairseq checkpoints into
#     fairseq2 state dicts.
# NOTE(review): ``restrict_checkpoints=False`` presumably lifts fairseq2's
# restricted (tensor/primitive-only) unpickling when reading checkpoints;
# confirm against the ModelLoader docs and only load trusted checkpoints.
load_monotonic_decoder_model = ModelLoader[
    MonotonicDecoderModel,
    MonotonicDecoderConfig,
](
    asset_store,
    download_manager,
    load_monotonic_decoder_config,
    create_monotonic_decoder_model,
    convert_monotonic_checkpoint,
    restrict_checkpoints=False,
)