NeuralFalcon commited on
Commit
dec16d6
·
verified ·
1 Parent(s): 5df77b1

Upload 2 files

Browse files
scripts/__init__.py ADDED
File without changes
scripts/convert_nnscaler_checkpoint_to_transformers.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ import argparse
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import re
9
+ import torch
10
+ from typing import Dict, List, Tuple
11
+
12
+ from vibevoice.modular.configuration_vibevoice import (
13
+ VibeVoiceConfig
14
+ )
15
+ from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
16
+ from transformers.utils import logging
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ def convert_vibevoice_nnscaler_checkpoint_to_hf(
21
+ checkpoint_path: str,
22
+ pytorch_dump_folder_path: str,
23
+ config_path: str = None,
24
+ ):
25
+ """
26
+ Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
27
+ Supports both regular checkpoints and tensor parallel checkpoints.
28
+ """
29
+
30
+ # Load regular checkpoint
31
+ logger.info(f"Loading regular checkpoint from {checkpoint_path}")
32
+ checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
33
+
34
+ # config = checkpoint['train_args']
35
+ init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
36
+ pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
37
+
38
+ init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
39
+ if init_config_path.exists():
40
+ logger.info(f"Loading initial config from {init_config_path}")
41
+ with open(init_config_path, 'r') as f:
42
+ init_config = json.load(f)
43
+ else:
44
+ raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
45
+
46
+ tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
47
+ logger.info(f"Tie word embeddings: {tie_word_embeddings}")
48
+
49
+ init_config['decoder_config']['use_cache'] = True
50
+ config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
51
+
52
+ # # Extract the model state dict
53
+ model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
54
+ if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
55
+ # If not tying weights, we need to add the lm_head weight separately
56
+ model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
57
+
58
+ # Override with provided config if available
59
+ if config_path:
60
+ logger.info(f"Loading config from {config_path}")
61
+ with open(config_path, 'r') as f:
62
+ config_dict = json.load(f)
63
+ config = VibeVoiceConfig.from_dict(config_dict)
64
+
65
+ # Set the default dtype to bfloat16 before creating the model
66
+ original_dtype = torch.get_default_dtype()
67
+ torch.set_default_dtype(torch.bfloat16)
68
+
69
+ # Create the HuggingFace model
70
+ logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
71
+ model = VibeVoiceForConditionalGeneration(config)
72
+
73
+ # Restore original dtype
74
+ torch.set_default_dtype(original_dtype)
75
+
76
+ # Load the state dict
77
+ logger.info("Loading weights into model")
78
+ missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
79
+
80
+ if missing_keys:
81
+ logger.warning(f"Missing keys: {missing_keys}")
82
+ if unexpected_keys:
83
+ logger.warning(f"Unexpected keys: {unexpected_keys}")
84
+
85
+ # Create output directory
86
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
87
+
88
+ # Save the model and config
89
+ logger.info(f"Saving model to {pytorch_dump_folder_path}")
90
+
91
+ # Save config
92
+ config.save_pretrained(pytorch_dump_folder_path)
93
+
94
+ # Save VibeVoiceProcessor configuration
95
+ logger.info("Saving VibeVoiceProcessor configuration")
96
+ processor_config = {
97
+ "processor_class": "VibeVoiceProcessor",
98
+ "speech_tok_compress_ratio": 3200,
99
+ "db_normalize": True,
100
+ # Audio processor configuration
101
+ "audio_processor": {
102
+ "feature_extractor_type": "VibeVoiceTokenizerProcessor",
103
+ "sampling_rate": 24000,
104
+ "normalize_audio": True,
105
+ "target_dB_FS": -25,
106
+ "eps": 1e-6,
107
+ },
108
+ "language_model_pretrained_name": pretrained_name,
109
+ }
110
+
111
+ processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
112
+ with open(processor_config_path, 'w') as f:
113
+ json.dump(processor_config, f, indent=2)
114
+ logger.info(f"Saved processor config to {processor_config_path}")
115
+
116
+ # Save model with sharding
117
+ # save_pretrained handles tied weights automatically
118
+ logger.info("Saving model weights with sharding...")
119
+ model.save_pretrained(
120
+ pytorch_dump_folder_path,
121
+ max_shard_size="2GB", # Set maximum size for each shard
122
+ safe_serialization=True # Ensure saving in .safetensors format
123
+ )
124
+ logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
125
+
126
+ logger.info("Conversion complete!")
127
+
128
+ # Verify the saved model can be loaded
129
+ logger.info("Verifying saved model...")
130
+ loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
131
+ logger.info("Model successfully loaded from saved checkpoint!")
132
+
133
+ def main():
134
+ parser = argparse.ArgumentParser()
135
+ parser.add_argument(
136
+ "--nnscaler_checkpoint_path",
137
+ type=str,
138
+ required=True,
139
+ help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
140
+ "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
141
+ "and the script will automatically detect and merge all parts.",
142
+ )
143
+ parser.add_argument(
144
+ "--pytorch_dump_folder_path",
145
+ type=str,
146
+ required=True,
147
+ help="Path to the output PyTorch model directory",
148
+ )
149
+ parser.add_argument(
150
+ "--config_path",
151
+ type=str,
152
+ default=None,
153
+ help="Optional path to a config JSON file to override extracted config",
154
+ )
155
+
156
+ args = parser.parse_args()
157
+
158
+ convert_vibevoice_nnscaler_checkpoint_to_hf(
159
+ args.nnscaler_checkpoint_path,
160
+ args.pytorch_dump_folder_path,
161
+ args.config_path,
162
+ )
163
+
164
+
165
+ if __name__ == "__main__":
166
+ main()