#!/usr/bin/env python
# coding=utf-8
import argparse
import json
import os
from pathlib import Path
import re
import torch
from typing import Dict, List, Tuple
from vibevoice.modular.configuration_vibevoice import (
VibeVoiceConfig
)
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from transformers.utils import logging
logger = logging.get_logger(__name__)
def convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path: str,
    pytorch_dump_folder_path: str,
    config_path: str = None,
):
    """
    Convert an nnscaler VibeVoice checkpoint to HuggingFace format.

    Supports both regular checkpoints and tensor parallel checkpoints.

    Args:
        checkpoint_path: Path to the nnscaler checkpoint (.pt) file.
        pytorch_dump_folder_path: Directory to write the HF model,
            config, and processor config into (created if missing).
        config_path: Optional path to a JSON config that overrides the
            config extracted from the checkpoint.

    Raises:
        FileNotFoundError: If the initial config file referenced by the
            checkpoint is not found under ``../configs`` next to this script.
    """
    logger.info(f"Loading regular checkpoint from {checkpoint_path}")
    # Checkpoint dict keys: 'model', 'optimizer', 'lr_scheduler',
    # 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader'.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Recover the config filename and tokenizer path recorded at training time.
    init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
    pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']

    init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
    if not init_config_path.exists():
        raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
    logger.info(f"Loading initial config from {init_config_path}")
    with open(init_config_path, 'r') as f:
        init_config = json.load(f)

    tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
    logger.info(f"Tie word embeddings: {tie_word_embeddings}")
    init_config['decoder_config']['use_cache'] = True
    config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)

    # Extract the model state dict, stripping the nnscaler wrapper
    # prefix ('model.model.' -> 'model.').
    model_state_dict = {
        k.replace('model.model.', 'model.'): v
        for k, v in checkpoint["model"].items()
        if k.startswith('model.model.')
    }
    if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"]:
        # If not tying weights, we need to add the lm_head weight separately.
        model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']

    # Override with provided config if available.
    if config_path:
        logger.info(f"Loading config from {config_path}")
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = VibeVoiceConfig.from_dict(config_dict)

    # Build the model under a bfloat16 default dtype; restore the previous
    # default even if construction fails (otherwise the process-wide default
    # dtype would be left as bfloat16 on error).
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
        model = VibeVoiceForConditionalGeneration(config)
    finally:
        torch.set_default_dtype(original_dtype)

    # Load the weights; strict=False so architecture-only keys are reported
    # rather than fatal.
    logger.info("Loading weights into model")
    missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
    if missing_keys:
        logger.warning(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {unexpected_keys}")

    # Create output directory and save the config.
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    logger.info(f"Saving model to {pytorch_dump_folder_path}")
    config.save_pretrained(pytorch_dump_folder_path)

    # Save VibeVoiceProcessor configuration alongside the model weights.
    logger.info("Saving VibeVoiceProcessor configuration")
    processor_config = {
        "processor_class": "VibeVoiceProcessor",
        "speech_tok_compress_ratio": 3200,
        "db_normalize": True,
        # Audio processor configuration
        "audio_processor": {
            "feature_extractor_type": "VibeVoiceTokenizerProcessor",
            "sampling_rate": 24000,
            "normalize_audio": True,
            "target_dB_FS": -25,
            "eps": 1e-6,
        },
        "language_model_pretrained_name": pretrained_name,
    }
    processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
    with open(processor_config_path, 'w') as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Saved processor config to {processor_config_path}")

    # Save model with sharding; save_pretrained handles tied weights
    # automatically.
    logger.info("Saving model weights with sharding...")
    model.save_pretrained(
        pytorch_dump_folder_path,
        max_shard_size="2GB",  # Set maximum size for each shard
        safe_serialization=True  # Ensure saving in .safetensors format
    )
    logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
    logger.info("Conversion complete!")

    # Verify the saved model can be loaded (result intentionally discarded;
    # this is a round-trip smoke check).
    logger.info("Verifying saved model...")
    loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
    logger.info("Model successfully loaded from saved checkpoint!")
def main():
    """Command-line entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser()
    # Declarative argument table; registered in a single loop below.
    arg_specs = [
        (
            "--nnscaler_checkpoint_path",
            dict(
                type=str,
                required=True,
                help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
                "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
                "and the script will automatically detect and merge all parts.",
            ),
        ),
        (
            "--pytorch_dump_folder_path",
            dict(
                type=str,
                required=True,
                help="Path to the output PyTorch model directory",
            ),
        ),
        (
            "--config_path",
            dict(
                type=str,
                default=None,
                help="Optional path to a config JSON file to override extracted config",
            ),
        ),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()
    convert_vibevoice_nnscaler_checkpoint_to_hf(
        parsed.nnscaler_checkpoint_path,
        parsed.pytorch_dump_folder_path,
        parsed.config_path,
    )
if __name__ == "__main__":
    main()