# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
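r"""Convert a Hugging Face Voxtral checkpoint (optionally FP8-quantized) to
Mistral's consolidated format.

Example invocation (the script filename is illustrative; the repo ids are the
argparse defaults below):

    python convert_voxtral_hf_to_mistral.py \
        --input_path_or_repo RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic \
        --unquantized_model_path mistralai/Voxtral-Mini-3B-2507 \
        --output_dir Voxtral-Mini-3B-2507-FP8-dynamic-converted
"""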
import argparse
import gc
import json
import os
import re

from huggingface_hub import snapshot_download
from safetensors.torch import safe_open, save_file

from transformers import VoxtralConfig

# fmt: off
STATE_DICT_MAPPING = {
    r"^language_model\.lm_head":                                         r"output",
    r"^language_model\.model\.norm":                                     r"norm",
    r"^language_model\.model\.embed_tokens":                             r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm":           r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm":  r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj":            r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj":            r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj":              r"layers.\1.feed_forward.w3",
    r"^audio_tower\.conv1":                                              r"mm_whisper_embeddings.whisper_encoder.conv_layers.0",
    r"^audio_tower\.conv2":                                              r"mm_whisper_embeddings.whisper_encoder.conv_layers.1",
    r"^audio_tower\.layer_norm":                                         r"mm_whisper_embeddings.whisper_encoder.transformer.norm",
    r"^audio_tower\.layers\.(\d+)\.self_attn\.(q|k|v)_proj":             r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.w\2",
    r"^audio_tower\.layers\.(\d+)\.self_attn\.out_proj":                 r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.wo",
    r"^audio_tower\.layers\.(\d+)\.self_attn_layer_norm":                r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention_norm",
    r"^audio_tower\.layers\.(\d+)\.fc(\d+)":                             r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.feed_forward.w\2",
    r"^audio_tower\.layers\.(\d+)\.final_layer_norm":                    r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.ffn_norm",
    r"^multi_modal_projector\.linear_1":                                 r"mm_whisper_embeddings.audio_language_projection.0",
    r"^multi_modal_projector\.linear_2":                                 r"mm_whisper_embeddings.audio_language_projection.2",
}
# fmt: on

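# The Whisper encoder's positional embeddings are fixed sinusoids, so the
# Mistral checkpoint does not store them and they are dropped here.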
SKIP_KEYS = ["audio_tower.embed_positions.weight"]


def add_quantization_config(config, hf_config: VoxtralConfig):
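    """Attach the HF quantization config to the Mistral params, remapping its `ignore` key names."""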
    quantization_config = hf_config.quantization_config
    mistral_ignore = [] # keys to ignore in the quantization config
    for hf_key in quantization_config["ignore"]:
        mistral_key = map_hf_key_to_mistral(hf_key)
        mistral_ignore.append(mistral_key)
    quantization_config["ignore"] = mistral_ignore
    config["quantization"] = quantization_config

    return config


def map_hf_key_to_mistral(hf_key):
    """Map a key from HF format to Mistral format"""
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, hf_key)
        if n_replace > 0:
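            # FP8 checkpoints name their scales "weight_scale"; Mistral expects "qscale_weight".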
            return new_key.replace("weight_scale", "qscale_weight")
    
    # If no mapping found, return the original key
    return hf_key.replace("weight_scale", "qscale_weight")


def permute_for_mistral_rope(tensor, n_heads, dim1, dim2):
    """Reverse the ROPE permutation to get back to Mistral format."""
    tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    tensor = tensor.transpose(1, 2)
    tensor = tensor.reshape(dim1, dim2)
    return tensor


def convert_state_dict(hf_state_dict, config):
    """Convert HF Voxtral state dict to Mistral format"""
    mistral_dict = {}
    
    num_attention_heads = config["n_heads"]
    hidden_size = config["dim"]
    head_dim = config["head_dim"]
    num_key_value_heads = config["n_kv_heads"]
    key_value_dim = head_dim * num_key_value_heads
    query_dim = head_dim * num_attention_heads
    
    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue
        
        mistral_key = map_hf_key_to_mistral(hf_key)
            
        if "language_model" in hf_key:
            if hf_key.endswith("q_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
            elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) == num_attention_heads:
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
            elif hf_key.endswith("k_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
            elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) == num_key_value_heads:
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)

        mistral_dict[mistral_key] = tensor
       
    return mistral_dict


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
):
    print("Converting HF Voxtral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    # Load the HF Voxtral model
    print(f"Loading HF Voxtral model from {input_path_or_repo}...")
    hf_config = VoxtralConfig.from_pretrained(input_path_or_repo)

    local_path = input_path_or_repo if os.path.exists(input_path_or_repo) else snapshot_download(input_path_or_repo)
    
    # Convert config: the Mistral params.json comes from the unquantized base
    # model, with the quantization config from the HF checkpoint grafted on.
    if unquantized_model_path is None:
        raise ValueError("unquantized_model_path is required to locate the base model's params.json")
    if not os.path.exists(unquantized_model_path):
        unquantized_model_path = snapshot_download(unquantized_model_path)
    config_path = os.path.join(unquantized_model_path, "params.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    config = add_quantization_config(config, hf_config)

    with open(os.path.join(output_dir, "params.json"), "w") as f:
        json.dump(config, f, indent=2)
    
    # Convert state dict
    print("Converting state dict...")
    tensor_files = sorted([f for f in os.listdir(local_path) if f.endswith(".safetensors")])

    hf_state_dict = {}

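    # Load every shard's tensors into memory up front. safe_open is pointed at
    # "cuda", so a CUDA-capable GPU is required for the conversion as written.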
    for file in tensor_files:
        file_path = os.path.join(local_path, file)
        with safe_open(file_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                hf_state_dict[key] = f.get_tensor(key)    

    mistral_state_dict = convert_state_dict(hf_state_dict, config)
   
    # save the state dict
    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))

    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")


def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Extract and save the tokenizer from Voxtral model"""
    from transformers import MistralCommonTokenizer
    
    print("Extracting tokenizer...")
    tokenizer = MistralCommonTokenizer.from_pretrained(input_path_or_repo)
    tokenizer.save_pretrained(output_dir)
    print("Tokenizer saved successfully.")


def main():
    parser = argparse.ArgumentParser(description="Convert HF Voxtral weights to Mistral format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        default="RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic",
        help="Path or repo containing HF Voxtral model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="Voxtral-Mini-3B-2507-FP8-dynamic-converted",
        help="Location to write Mistral model and tokenizer",
    )
    parser.add_argument(
        "--skip_tokenizer",
        action="store_true",
        help="Skip tokenizer conversion"
    )
    parser.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Voxtral-Mini-3B-2507",
        help="Path to the unquantized model",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.output_dir,
        unquantized_model_path=args.unquantized_model_path,
    )

    if not args.skip_tokenizer:
        write_tokenizer(
            args.input_path_or_repo,
            args.output_dir,
        )


if __name__ == "__main__":
    main()