Spaces:
Running
Running
Upload 2 files
Browse files
scripts/__init__.py
ADDED
File without changes
|
scripts/convert_nnscaler_checkpoint_to_transformers.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from typing import Dict, List, Tuple
|
11 |
+
|
12 |
+
from vibevoice.modular.configuration_vibevoice import (
|
13 |
+
VibeVoiceConfig
|
14 |
+
)
|
15 |
+
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
|
16 |
+
from transformers.utils import logging
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__)
|
19 |
+
|
20 |
+
def convert_vibevoice_nnscaler_checkpoint_to_hf(
|
21 |
+
checkpoint_path: str,
|
22 |
+
pytorch_dump_folder_path: str,
|
23 |
+
config_path: str = None,
|
24 |
+
):
|
25 |
+
"""
|
26 |
+
Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
|
27 |
+
Supports both regular checkpoints and tensor parallel checkpoints.
|
28 |
+
"""
|
29 |
+
|
30 |
+
# Load regular checkpoint
|
31 |
+
logger.info(f"Loading regular checkpoint from {checkpoint_path}")
|
32 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
|
33 |
+
|
34 |
+
# config = checkpoint['train_args']
|
35 |
+
init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
|
36 |
+
pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
|
37 |
+
|
38 |
+
init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
|
39 |
+
if init_config_path.exists():
|
40 |
+
logger.info(f"Loading initial config from {init_config_path}")
|
41 |
+
with open(init_config_path, 'r') as f:
|
42 |
+
init_config = json.load(f)
|
43 |
+
else:
|
44 |
+
raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
|
45 |
+
|
46 |
+
tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
|
47 |
+
logger.info(f"Tie word embeddings: {tie_word_embeddings}")
|
48 |
+
|
49 |
+
init_config['decoder_config']['use_cache'] = True
|
50 |
+
config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
|
51 |
+
|
52 |
+
# # Extract the model state dict
|
53 |
+
model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
|
54 |
+
if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
|
55 |
+
# If not tying weights, we need to add the lm_head weight separately
|
56 |
+
model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
|
57 |
+
|
58 |
+
# Override with provided config if available
|
59 |
+
if config_path:
|
60 |
+
logger.info(f"Loading config from {config_path}")
|
61 |
+
with open(config_path, 'r') as f:
|
62 |
+
config_dict = json.load(f)
|
63 |
+
config = VibeVoiceConfig.from_dict(config_dict)
|
64 |
+
|
65 |
+
# Set the default dtype to bfloat16 before creating the model
|
66 |
+
original_dtype = torch.get_default_dtype()
|
67 |
+
torch.set_default_dtype(torch.bfloat16)
|
68 |
+
|
69 |
+
# Create the HuggingFace model
|
70 |
+
logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
|
71 |
+
model = VibeVoiceForConditionalGeneration(config)
|
72 |
+
|
73 |
+
# Restore original dtype
|
74 |
+
torch.set_default_dtype(original_dtype)
|
75 |
+
|
76 |
+
# Load the state dict
|
77 |
+
logger.info("Loading weights into model")
|
78 |
+
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
|
79 |
+
|
80 |
+
if missing_keys:
|
81 |
+
logger.warning(f"Missing keys: {missing_keys}")
|
82 |
+
if unexpected_keys:
|
83 |
+
logger.warning(f"Unexpected keys: {unexpected_keys}")
|
84 |
+
|
85 |
+
# Create output directory
|
86 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
87 |
+
|
88 |
+
# Save the model and config
|
89 |
+
logger.info(f"Saving model to {pytorch_dump_folder_path}")
|
90 |
+
|
91 |
+
# Save config
|
92 |
+
config.save_pretrained(pytorch_dump_folder_path)
|
93 |
+
|
94 |
+
# Save VibeVoiceProcessor configuration
|
95 |
+
logger.info("Saving VibeVoiceProcessor configuration")
|
96 |
+
processor_config = {
|
97 |
+
"processor_class": "VibeVoiceProcessor",
|
98 |
+
"speech_tok_compress_ratio": 3200,
|
99 |
+
"db_normalize": True,
|
100 |
+
# Audio processor configuration
|
101 |
+
"audio_processor": {
|
102 |
+
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
|
103 |
+
"sampling_rate": 24000,
|
104 |
+
"normalize_audio": True,
|
105 |
+
"target_dB_FS": -25,
|
106 |
+
"eps": 1e-6,
|
107 |
+
},
|
108 |
+
"language_model_pretrained_name": pretrained_name,
|
109 |
+
}
|
110 |
+
|
111 |
+
processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
|
112 |
+
with open(processor_config_path, 'w') as f:
|
113 |
+
json.dump(processor_config, f, indent=2)
|
114 |
+
logger.info(f"Saved processor config to {processor_config_path}")
|
115 |
+
|
116 |
+
# Save model with sharding
|
117 |
+
# save_pretrained handles tied weights automatically
|
118 |
+
logger.info("Saving model weights with sharding...")
|
119 |
+
model.save_pretrained(
|
120 |
+
pytorch_dump_folder_path,
|
121 |
+
max_shard_size="2GB", # Set maximum size for each shard
|
122 |
+
safe_serialization=True # Ensure saving in .safetensors format
|
123 |
+
)
|
124 |
+
logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
|
125 |
+
|
126 |
+
logger.info("Conversion complete!")
|
127 |
+
|
128 |
+
# Verify the saved model can be loaded
|
129 |
+
logger.info("Verifying saved model...")
|
130 |
+
loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
|
131 |
+
logger.info("Model successfully loaded from saved checkpoint!")
|
132 |
+
|
133 |
+
def main():
|
134 |
+
parser = argparse.ArgumentParser()
|
135 |
+
parser.add_argument(
|
136 |
+
"--nnscaler_checkpoint_path",
|
137 |
+
type=str,
|
138 |
+
required=True,
|
139 |
+
help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
|
140 |
+
"provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
|
141 |
+
"and the script will automatically detect and merge all parts.",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--pytorch_dump_folder_path",
|
145 |
+
type=str,
|
146 |
+
required=True,
|
147 |
+
help="Path to the output PyTorch model directory",
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--config_path",
|
151 |
+
type=str,
|
152 |
+
default=None,
|
153 |
+
help="Optional path to a config JSON file to override extracted config",
|
154 |
+
)
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
convert_vibevoice_nnscaler_checkpoint_to_hf(
|
159 |
+
args.nnscaler_checkpoint_path,
|
160 |
+
args.pytorch_dump_folder_path,
|
161 |
+
args.config_path,
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
if __name__ == "__main__":
|
166 |
+
main()
|