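"""Convert Hugging Face Voxtral weights to the Mistral (consolidated) checkpoint format.

The script maps HF parameter names to Mistral ones, undoes the HF RoPE permutation on the
q/k projections, copies the quantization config from the HF model, and writes `params.json`,
`consolidated.safetensors`, and (optionally) the tokenizer to the output directory.

Illustrative invocation (the script filename is a placeholder; the defaults match the
argument parser in `main`):

    python convert_voxtral_hf_to_mistral.py \
        --input_path_or_repo RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic \
        --unquantized_model_path mistralai/Voxtral-Mini-3B-2507 \
        --output_dir Voxtral-Mini-3B-2507-FP8-dynamic-converted
"""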

import argparse
import gc
import json
import os
import re

from huggingface_hub import snapshot_download
from safetensors.torch import safe_open, save_file
from transformers import VoxtralConfig


# Regex patterns mapping HF parameter names (left) to Mistral parameter names (right).
# Capture groups carry over layer indices and projection letters.
STATE_DICT_MAPPING = {
    # Text decoder (language model).
    r"^language_model\.lm_head": r"output",
    r"^language_model\.model\.norm": r"norm",
    r"^language_model\.model\.embed_tokens": r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm": r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm": r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj": r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj": r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj": r"layers.\1.feed_forward.w3",
    # Whisper-based audio encoder.
    r"^audio_tower\.conv1": r"mm_whisper_embeddings.whisper_encoder.conv_layers.0",
    r"^audio_tower\.conv2": r"mm_whisper_embeddings.whisper_encoder.conv_layers.1",
    r"^audio_tower\.layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.norm",
    r"^audio_tower\.layers\.(\d+)\.self_attn\.(q|k|v)_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.w\2",
    r"^audio_tower\.layers\.(\d+)\.self_attn\.out_proj": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention.wo",
    r"^audio_tower\.layers\.(\d+)\.self_attn_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.attention_norm",
    r"^audio_tower\.layers\.(\d+)\.fc(\d+)": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.feed_forward.w\2",
    r"^audio_tower\.layers\.(\d+)\.final_layer_norm": r"mm_whisper_embeddings.whisper_encoder.transformer.layers.\1.ffn_norm",
    # Audio-to-language projector.
    r"^multi_modal_projector\.linear_1": r"mm_whisper_embeddings.audio_language_projection.0",
    r"^multi_modal_projector\.linear_2": r"mm_whisper_embeddings.audio_language_projection.2",
}


# HF keys that have no counterpart in the Mistral checkpoint format.
SKIP_KEYS = ["audio_tower.embed_positions.weight"]


def add_quantization_config(config, hf_config: VoxtralConfig):
    """Copy the HF quantization config into the Mistral params, remapping the `ignore` keys."""
    quantization_config = hf_config.quantization_config
    mistral_ignore = []
    for hf_key in quantization_config["ignore"]:
        mistral_key = map_hf_key_to_mistral(hf_key)
        mistral_ignore.append(mistral_key)
    quantization_config["ignore"] = mistral_ignore
    config["quantization"] = quantization_config

    return config


def map_hf_key_to_mistral(hf_key):
    """Map a key from HF format to Mistral format."""
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, hf_key)
        if n_replace > 0:
            # Quantization scales are named differently in the two formats.
            return new_key.replace("weight_scale", "qscale_weight")

    # No pattern matched: keep the key, still renaming quantization scales.
    return hf_key.replace("weight_scale", "qscale_weight")
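

# For example, map_hf_key_to_mistral maps (illustrative, derived from STATE_DICT_MAPPING;
# not an exhaustive list):
#   "language_model.lm_head.weight"                          -> "output.weight"
#   "language_model.model.layers.0.self_attn.q_proj.weight"  -> "layers.0.attention.wq.weight"
#   "audio_tower.layers.3.fc1.weight_scale"
#       -> "mm_whisper_embeddings.whisper_encoder.transformer.layers.3.feed_forward.w1.qscale_weight"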


def permute_for_mistral_rope(tensor, n_heads, dim1, dim2):
    """Reverse the RoPE permutation to get back to Mistral format."""
    tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    tensor = tensor.transpose(1, 2)
    tensor = tensor.reshape(dim1, dim2)
    return tensor
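

# For reference only: a minimal sketch of the forward permutation this script assumes the HF
# checkpoint was produced with (the usual Llama-style q/k RoPE permutation). It is not called
# during conversion; round-tripping it with permute_for_mistral_rope is the identity.
def _permute_for_hf_rope(tensor, n_heads, dim1, dim2):
    """Assumed HF-side permutation; inverse of `permute_for_mistral_rope` (reference only)."""
    tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
    tensor = tensor.transpose(1, 2)
    return tensor.reshape(dim1, dim2)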


def convert_state_dict(hf_state_dict, config):
    """Convert HF Voxtral state dict to Mistral format."""
    mistral_dict = {}

    num_attention_heads = config["n_heads"]
    hidden_size = config["dim"]
    head_dim = config["head_dim"]
    num_key_value_heads = config["n_kv_heads"]
    key_value_dim = head_dim * num_key_value_heads
    query_dim = head_dim * num_attention_heads

    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue

        mistral_key = map_hf_key_to_mistral(hf_key)

        # Only the language model's q/k projections carry the HF RoPE permutation; their
        # weight scales are permuted too when their leading dimension is the head count.
        if "language_model" in hf_key:
            if hf_key.endswith("q_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
            elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) == num_attention_heads:
                tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
            elif hf_key.endswith("k_proj.weight"):
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
            elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) == num_key_value_heads:
                tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)

        mistral_dict[mistral_key] = tensor

    return mistral_dict


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
):
    print("Converting HF Voxtral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading HF Voxtral model from {input_path_or_repo}...")
    hf_config = VoxtralConfig.from_pretrained(input_path_or_repo)
    local_path = snapshot_download(input_path_or_repo)

    if unquantized_model_path is None:
        raise ValueError("unquantized_model_path is required: params.json is taken from the unquantized model.")

    # Resolve to a local directory, downloading from the Hub if necessary.
    if not os.path.exists(unquantized_model_path):
        unquantized_model_path = snapshot_download(unquantized_model_path)

    config_path = os.path.join(unquantized_model_path, "params.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    config = add_quantization_config(config, hf_config)

    with open(os.path.join(output_dir, "params.json"), "w") as f:
        json.dump(config, f, indent=2)

    print("Converting state dict...")
    tensor_files = sorted(f for f in os.listdir(local_path) if f.endswith(".safetensors"))

    # Load every shard into memory (tensors go to the GPU; use device="cpu" if no CUDA device is available).
    hf_state_dict = {}
    for file in tensor_files:
        file_path = os.path.join(local_path, file)
        with safe_open(file_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                hf_state_dict[key] = f.get_tensor(key)

    mistral_state_dict = convert_state_dict(hf_state_dict, config)

    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))

    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")


def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Extract and save the tokenizer from the Voxtral model."""
    from transformers import MistralCommonTokenizer

    print("Extracting tokenizer...")
    tokenizer = MistralCommonTokenizer.from_pretrained(input_path_or_repo)
    tokenizer.save_pretrained(output_dir)
    print("Tokenizer saved successfully.")


def main():
    parser = argparse.ArgumentParser(description="Convert HF Voxtral weights to Mistral format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        default="RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic",
        help="Path or repo containing the HF Voxtral model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="Voxtral-Mini-3B-2507-FP8-dynamic-converted",
        help="Location to write the Mistral model and tokenizer",
    )
    parser.add_argument(
        "--skip_tokenizer",
        action="store_true",
        help="Skip tokenizer conversion",
    )
    parser.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Voxtral-Mini-3B-2507",
        help="Path or repo of the unquantized model (source of params.json)",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.output_dir,
        unquantized_model_path=args.unquantized_model_path,
    )

    if not args.skip_tokenizer:
        write_tokenizer(
            args.input_path_or_repo,
            args.output_dir,
        )


if __name__ == "__main__":
    main()