# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import re

from huggingface_hub import snapshot_download
from safetensors.torch import safe_open, save_file
from transformers import VoxtralConfig

# fmt: off
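# Each entry maps a regex over HF parameter names to its Mistral-format name.
# Capture groups carry layer indices and projection letters through, e.g.
#   "language_model.model.layers.7.self_attn.k_proj.weight" -> "layers.7.attention.wk.weight"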
STATE_DICT_MAPPING = {
    r"^language_model\.lm_head": r"output",
    r"^language_model\.model\.norm": r"norm",
    r"^language_model\.model\.embed_tokens": r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm": r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm": r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj": r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj": r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj": r"layers.\1.feed_forward.w3",
    r"language_model.model.embed_tokens": r"tok_embeddings",
    r"audio_tower.conv1": r"mm_whisper_embeddings.whisper_encoder.conv_layers.0",
    r"audio_tower.conv2": r"mm_whisper_embeddings.whisper_encoder.conv_layers.1",
    r"audio_tower.layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.norm",
    r"audio_tower.layers.(\d+).self_attn.(q|k|v)_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.w\2",
    r"audio_tower.layers.(\d+).self_attn.out_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.wo",
    r"audio_tower.layers.(\d+).self_attn_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention_norm",
    r"audio_tower.layers.(\d+).fc(\d+)": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.feed_forward.w\2",
    r"audio_tower.layers.(\d+).final_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.ffn_norm",
    r"multi_modal_projector.linear_1": r"mm_whisper_embeddings.audio_language_projection.0",
    r"multi_modal_projector.linear_2": r"mm_whisper_embeddings.audio_language_projection.2",
}
# fmt: on
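
# Whisper's embed_positions are fixed sinusoidal position embeddings; the
# Mistral-format checkpoint does not store them (presumably regenerated at load time).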
SKIP_KEYS = ["audio_tower.embed_positions.weight"]


def add_quantization_config(config, hf_config: VoxtralConfig):
    """Copy the HF quantization config into the Mistral params, remapping ignored module names."""
    quantization_config = hf_config.quantization_config
    mistral_ignore = []  # keys to ignore in the quantization config, in Mistral naming
    for hf_key in quantization_config["ignore"]:
        mistral_key = map_hf_key_to_mistral(hf_key)
        mistral_ignore.append(mistral_key)
    quantization_config["ignore"] = mistral_ignore
    config["quantization"] = quantization_config
    return config
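
# For example, an HF ignore entry like "language_model.lm_head" is remapped to
# "output" in the Mistral quantization config.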


def map_hf_key_to_mistral(hf_key):
    """Map a key from HF format to Mistral format."""
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, hf_key)
        if n_replace > 0:
            # FP8 checkpoints call per-channel scales "weight_scale" in HF
            # and "qscale_weight" in the Mistral format.
            return new_key.replace("weight_scale", "qscale_weight")
    # If no mapping is found, return the original key (still renaming scales).
    return hf_key.replace("weight_scale", "qscale_weight")
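
# For illustration:
#   map_hf_key_to_mistral("audio_tower.layers.3.fc1.weight")
#     -> "mm_whisper_embeddings.whisper_encoder.transformer.layers.3.feed_forward.w1.weight"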


def permute_for_mistral_rope(tensor, n_heads, dim1, dim2):
    """Reverse the RoPE permutation applied during HF conversion to recover the Mistral layout."""
    tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    tensor = tensor.transpose(1, 2)
    tensor = tensor.reshape(dim1, dim2)
    return tensor
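
# A minimal round-trip sketch (assuming the standard HF Llama-style permutation):
#   hf = w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
#   assert torch.equal(permute_for_mistral_rope(hf, n_heads, dim1, dim2), w)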


def convert_state_dict(hf_state_dict, config):
    """Convert an HF Voxtral state dict to Mistral format."""
    mistral_dict = {}

    num_attention_heads = config["n_heads"]
    hidden_size = config["dim"]
    head_dim = config["head_dim"]
    num_key_value_heads = config["n_kv_heads"]
    key_value_dim = head_dim * num_key_value_heads
    query_dim = head_dim * num_attention_heads

    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue
        mistral_key = map_hf_key_to_mistral(hf_key)

        if "language_model" in hf_key:
            # Undo the HF rotary permutation on the decoder's q/k projections;
            # their quantization scales must follow the same row reordering.
            if hf_key.endswith("q_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
            elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) == num_attention_heads:
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
            elif hf_key.endswith("k_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
            elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) == num_key_value_heads:
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)

        mistral_dict[mistral_key] = tensor

    return mistral_dict
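
# Only the q/k projections are permuted: RoPE is applied to queries and keys,
# so those are the only weights whose row layout differs between the two formats.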


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
):
    print("Converting HF Voxtral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    # Load the HF Voxtral model
    print(f"Loading HF Voxtral model from {input_path_or_repo}...")
    hf_config = VoxtralConfig.from_pretrained(input_path_or_repo)
    local_path = snapshot_download(input_path_or_repo)

    # Convert config: params.json comes from the unquantized base model and is
    # extended with the quantization config of the FP8 checkpoint.
    if unquantized_model_path is not None:
        if not os.path.exists(unquantized_model_path):
            unquantized_model_path = snapshot_download(unquantized_model_path)
        config_path = os.path.join(unquantized_model_path, "params.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        config = add_quantization_config(config, hf_config)
        with open(os.path.join(output_dir, "params.json"), "w") as f:
            json.dump(config, f, indent=2)
    else:
        raise ValueError("unquantized_model_path is required: it provides the base params.json")

    # Convert state dict
    print("Converting state dict...")
    tensor_files = sorted(f for f in os.listdir(local_path) if f.endswith(".safetensors"))
    hf_state_dict = {}
    for file in tensor_files:
        file_path = os.path.join(local_path, file)
        # NOTE: tensors are loaded onto GPU; use device="cpu" if CUDA is unavailable.
        with safe_open(file_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                hf_state_dict[key] = f.get_tensor(key)

    mistral_state_dict = convert_state_dict(hf_state_dict, config)

    # Save the converted state dict as a single consolidated file.
    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))
    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")


def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Extract and save the tokenizer from the Voxtral model."""
    from transformers import MistralCommonTokenizer

    print("Extracting tokenizer...")
    tokenizer = MistralCommonTokenizer.from_pretrained(input_path_or_repo)
    tokenizer.save_pretrained(output_dir)
    print("Tokenizer saved successfully.")


def main():
    parser = argparse.ArgumentParser(description="Convert HF Voxtral weights to Mistral format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        default="RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic",
        help="Path or repo containing the HF Voxtral model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="Voxtral-Mini-3B-2507-FP8-dynamic-converted",
        help="Location to write the Mistral model and tokenizer",
    )
    parser.add_argument(
        "--skip_tokenizer",
        action="store_true",
        help="Skip tokenizer conversion",
    )
    parser.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Voxtral-Mini-3B-2507",
        help="Path or repo of the unquantized base model (source of params.json)",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.output_dir,
        unquantized_model_path=args.unquantized_model_path,
    )
    if not args.skip_tokenizer:
        write_tokenizer(
            args.input_path_or_repo,
            args.output_dir,
        )


if __name__ == "__main__":
    main()
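

# Example usage (the defaults mirror the argparse definitions above):
#   python convert_voxtral_hf_to_mistral.py \
#       --input_path_or_repo RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic \
#       --output_dir Voxtral-Mini-3B-2507-FP8-dynamic-converted \
#       --unquantized_model_path mistralai/Voxtral-Mini-3B-2507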