yasserrmd committed · verified
Commit 20a29ac · 1 Parent(s): 3f31b6d

Upload folder using huggingface_hub

__init__.py ADDED
File without changes
configs/qwen2.5_1.5b_64k.json ADDED
@@ -0,0 +1,112 @@
+ {
+   "_attn_implementation_autoset": true,
+   "acoustic_vae_dim": 64,
+   "acoustic_tokenizer_config": {
+     "causal": true,
+     "channels": 1,
+     "conv_bias": true,
+     "conv_norm": "none",
+     "corpus_normalize": 0.0,
+     "decoder_depths": null,
+     "decoder_n_filters": 32,
+     "decoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "disable_last_norm": true,
+     "encoder_depths": "3-3-3-3-3-3-8",
+     "encoder_n_filters": 32,
+     "encoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "fix_std": 0.5,
+     "layer_scale_init_value": 1e-06,
+     "layernorm": "RMSNorm",
+     "layernorm_elementwise_affine": true,
+     "layernorm_eps": 1e-05,
+     "mixer_layer": "depthwise_conv",
+     "model_type": "vibepod_acoustic_tokenizer",
+     "pad_mode": "constant",
+     "std_dist_type": "gaussian",
+     "vae_dim": 64,
+     "weight_init_value": 0.01
+   },
+   "decoder_config": {
+     "attention_dropout": 0.0,
+     "hidden_act": "silu",
+     "hidden_size": 1536,
+     "initializer_range": 0.02,
+     "intermediate_size": 8960,
+     "max_position_embeddings": 65536,
+     "max_window_layers": 28,
+     "model_type": "qwen2",
+     "num_attention_heads": 12,
+     "num_hidden_layers": 28,
+     "num_key_value_heads": 2,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": null,
+     "rope_theta": 1000000.0,
+     "sliding_window": null,
+     "tie_word_embeddings": true,
+     "torch_dtype": "bfloat16",
+     "use_cache": true,
+     "use_sliding_window": false,
+     "vocab_size": 151936
+   },
+   "diffusion_head_config": {
+     "ddpm_batch_mul": 4,
+     "ddpm_beta_schedule": "cosine",
+     "ddpm_num_inference_steps": 20,
+     "ddpm_num_steps": 1000,
+     "diffusion_type": "ddpm",
+     "head_ffn_ratio": 3.0,
+     "head_layers": 4,
+     "hidden_size": 1536,
+     "latent_size": 64,
+     "model_type": "vibepod_diffusion_head",
+     "prediction_type": "v_prediction",
+     "rms_norm_eps": 1e-05,
+     "speech_vae_dim": 64
+   },
+   "model_type": "vibepod",
+   "semantic_tokenizer_config": {
+     "causal": true,
+     "channels": 1,
+     "conv_bias": true,
+     "conv_norm": "none",
+     "corpus_normalize": 0.0,
+     "disable_last_norm": true,
+     "encoder_depths": "3-3-3-3-3-3-8",
+     "encoder_n_filters": 32,
+     "encoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "fix_std": 0,
+     "layer_scale_init_value": 1e-06,
+     "layernorm": "RMSNorm",
+     "layernorm_elementwise_affine": true,
+     "layernorm_eps": 1e-05,
+     "mixer_layer": "depthwise_conv",
+     "model_type": "vibepod_semantic_tokenizer",
+     "pad_mode": "constant",
+     "std_dist_type": "none",
+     "vae_dim": 128,
+     "weight_init_value": 0.01
+   },
+   "semantic_vae_dim": 128,
+   "torch_dtype": "bfloat16"
+ }
configs/qwen2.5_7b_32k.json ADDED
@@ -0,0 +1,113 @@
+ {
+   "_attn_implementation_autoset": true,
+   "acoustic_vae_dim": 64,
+   "acoustic_tokenizer_config": {
+     "causal": true,
+     "channels": 1,
+     "conv_bias": true,
+     "conv_norm": "none",
+     "corpus_normalize": 0.0,
+     "decoder_depths": null,
+     "decoder_n_filters": 32,
+     "decoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "disable_last_norm": true,
+     "encoder_depths": "3-3-3-3-3-3-8",
+     "encoder_n_filters": 32,
+     "encoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "fix_std": 0.5,
+     "layer_scale_init_value": 1e-06,
+     "layernorm": "RMSNorm",
+     "layernorm_elementwise_affine": true,
+     "layernorm_eps": 1e-05,
+     "mixer_layer": "depthwise_conv",
+     "model_type": "vibepod_acoustic_tokenizer",
+     "pad_mode": "constant",
+     "std_dist_type": "gaussian",
+     "vae_dim": 64,
+     "weight_init_value": 0.01
+   },
+   "decoder_config": {
+     "attention_dropout": 0.0,
+     "hidden_act": "silu",
+     "hidden_size": 3584,
+     "initializer_range": 0.02,
+     "intermediate_size": 18944,
+     "max_position_embeddings": 32768,
+     "max_window_layers": 28,
+     "model_type": "qwen2",
+     "num_attention_heads": 28,
+     "num_hidden_layers": 28,
+     "num_key_value_heads": 4,
+     "rms_norm_eps": 1e-06,
+     "rope_theta": 1000000.0,
+     "sliding_window": null,
+     "tie_word_embeddings": false,
+     "torch_dtype": "bfloat16",
+     "transformers_version": "4.40.1",
+     "use_cache": true,
+     "use_mrope": false,
+     "use_sliding_window": false,
+     "vocab_size": 152064
+   },
+   "diffusion_head_config": {
+     "ddpm_batch_mul": 4,
+     "ddpm_beta_schedule": "cosine",
+     "ddpm_num_inference_steps": 20,
+     "ddpm_num_steps": 1000,
+     "diffusion_type": "ddpm",
+     "head_ffn_ratio": 3.0,
+     "head_layers": 4,
+     "hidden_size": 3584,
+     "latent_size": 64,
+     "model_type": "vibepod_diffusion_head",
+     "prediction_type": "v_prediction",
+     "rms_norm_eps": 1e-05,
+     "speech_vae_dim": 64
+   },
+   "model_type": "vibepod",
+   "semantic_tokenizer_config": {
+     "causal": true,
+     "channels": 1,
+     "conv_bias": true,
+     "conv_norm": "none",
+     "corpus_normalize": 0.0,
+     "disable_last_norm": true,
+     "encoder_depths": "3-3-3-3-3-3-8",
+     "encoder_n_filters": 32,
+     "encoder_ratios": [
+       8,
+       5,
+       5,
+       4,
+       2,
+       2
+     ],
+     "fix_std": 0,
+     "layer_scale_init_value": 1e-06,
+     "layernorm": "RMSNorm",
+     "layernorm_elementwise_affine": true,
+     "layernorm_eps": 1e-05,
+     "mixer_layer": "depthwise_conv",
+     "model_type": "vibepod_semantic_tokenizer",
+     "pad_mode": "constant",
+     "std_dist_type": "none",
+     "vae_dim": 128,
+     "weight_init_value": 0.01
+   },
+   "semantic_vae_dim": 128,
+   "torch_dtype": "bfloat16"
+ }
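
The two configs share the same tokenizer and diffusion-head settings and differ only in the Qwen2 decoder: the 1.5B/64k variant uses hidden_size 1536, 65536 max positions, tied embeddings and a 151936-token vocab, while the 7B/32k variant uses hidden_size 3584, 32768 max positions, untied embeddings and a 152064-token vocab. A minimal loading sketch (not part of the commit; it assumes the repo root is on PYTHONPATH so `modular.configuration_vibevoice`, added below, is importable, and that a compatible transformers version is installed):

import json
from modular.configuration_vibevoice import VibeVoiceConfig

with open("configs/qwen2.5_1.5b_64k.json") as f:
    cfg_dict = json.load(f)

# VibeVoiceConfig re-instantiates each sub-config dict and normalizes the
# sub-config model_type fields before storing them as attributes.
config = VibeVoiceConfig(**cfg_dict)

print(config.decoder_config.hidden_size)               # 1536
print(config.decoder_config.max_position_embeddings)   # 65536
print(config.acoustic_vae_dim, config.semantic_vae_dim)  # 64 128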
modular/__init__.py ADDED
File without changes
modular/configuration_vibevoice.py ADDED
@@ -0,0 +1,248 @@
+ """ VibeVoice_AcousticTokenizer model configuration"""
+
+ from typing import Dict, List, Optional, Tuple
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+ logger = logging.get_logger(__name__)
+
+
+ class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
+     model_type = "vibevoice_acoustic_tokenizer"
+
+     def __init__(
+         self,
+         channels: int = 1,
+         corpus_normalize: float = 0.0,
+         causal: bool = True,
+         vae_dim: int = 64,
+         fix_std: float = 0.5,
+         std_dist_type: str = 'gaussian',
+         # common
+         mixer_layer: str = 'depthwise_conv',
+         conv_norm: str = 'none',
+         pad_mode: str = 'constant',
+         disable_last_norm: bool = True,
+         layernorm: str = 'RMSNorm',
+         layernorm_eps: float = 1e-5,
+         layernorm_elementwise_affine: bool = True,
+         conv_bias: bool = True,
+         layer_scale_init_value: float = 1e-6,
+         weight_init_value: float = 1e-2,
+         # encoder specific
+         encoder_n_filters: int = 32,
+         encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
+         encoder_depths: str = "3-3-3-3-3-3-8",
+         # decoder specific
+         decoder_n_filters: int = 32,
+         decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
+         decoder_depths: Optional[str] = None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.channels = channels
+         self.corpus_normalize = corpus_normalize
+         self.causal = causal
+         self.vae_dim = vae_dim
+         self.fix_std = fix_std
+         self.std_dist_type = std_dist_type
+
+         # common parameters
+         self.conv_norm = conv_norm
+         self.pad_mode = pad_mode
+         self.layernorm_eps = layernorm_eps
+         self.disable_last_norm = disable_last_norm
+         self.layernorm = layernorm
+         self.layernorm_elementwise_affine = layernorm_elementwise_affine
+         self.conv_bias = conv_bias
+         self.layer_scale_init_value = layer_scale_init_value
+         self.weight_init_value = weight_init_value
+         self.mixer_layer = mixer_layer
+
+         # encoder specific parameters
+         self.encoder_n_filters = encoder_n_filters
+         self.encoder_ratios = encoder_ratios
+         self.encoder_depths = encoder_depths
+
+         # decoder specific parameters
+         self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
+         self.decoder_n_filters = decoder_n_filters
+         self.decoder_depths = decoder_depths
+
+
+ class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
+     model_type = "vibevoice_semantic_tokenizer"
+
+     def __init__(
+         self,
+         channels: int = 1,
+         corpus_normalize: float = 0.0,
+         causal: bool = True,
+         vae_dim: int = 64,
+         fix_std: float = 0,
+         std_dist_type: str = 'none',
+         # common
+         mixer_layer: str = 'depthwise_conv',
+         conv_norm: str = 'none',
+         pad_mode: str = 'constant',
+         disable_last_norm: bool = True,
+         layernorm: str = 'RMSNorm',
+         layernorm_eps: float = 1e-5,
+         layernorm_elementwise_affine: bool = True,
+         conv_bias: bool = True,
+         layer_scale_init_value: float = 1e-6,
+         weight_init_value: float = 1e-2,
+         # encoder specific
+         encoder_n_filters: int = 32,
+         encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
+         encoder_depths: str = "3-3-3-3-3-3-8",
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.channels = channels
+         self.corpus_normalize = corpus_normalize
+         self.causal = causal
+         self.vae_dim = vae_dim
+         self.fix_std = fix_std
+         self.std_dist_type = std_dist_type
+
+         # common parameters
+         self.conv_norm = conv_norm
+         self.pad_mode = pad_mode
+         self.layernorm_eps = layernorm_eps
+         self.disable_last_norm = disable_last_norm
+         self.layernorm = layernorm
+         self.layernorm_elementwise_affine = layernorm_elementwise_affine
+         self.conv_bias = conv_bias
+         self.layer_scale_init_value = layer_scale_init_value
+         self.weight_init_value = weight_init_value
+         self.mixer_layer = mixer_layer
+
+         # encoder specific parameters
+         self.encoder_n_filters = encoder_n_filters
+         self.encoder_ratios = encoder_ratios
+         self.encoder_depths = encoder_depths
+
+
+ class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
+     model_type = "vibevoice_diffusion_head"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         head_layers=4,
+         head_ffn_ratio=3.0,
+         rms_norm_eps=1e-5,
+         latent_size=64,
+         speech_vae_dim=None,
+         prediction_type="v_prediction",
+         diffusion_type="ddpm",
+         ddpm_num_steps=1000,
+         ddpm_num_inference_steps=20,
+         ddpm_beta_schedule="cosine",
+         ddpm_batch_mul=4,
+         **kwargs
+     ):
+         self.hidden_size = hidden_size
+         self.head_layers = head_layers
+         self.head_ffn_ratio = head_ffn_ratio
+         self.rms_norm_eps = rms_norm_eps
+         self.latent_size = latent_size
+         self.speech_vae_dim = speech_vae_dim
+         self.prediction_type = prediction_type
+         self.diffusion_type = diffusion_type
+         self.ddpm_num_steps = ddpm_num_steps
+         self.ddpm_num_inference_steps = ddpm_num_inference_steps
+         self.ddpm_beta_schedule = ddpm_beta_schedule
+         self.ddpm_batch_mul = ddpm_batch_mul
+
+         super().__init__(**kwargs)
+
+ class VibeVoiceConfig(PretrainedConfig):
+     model_type = "vibevoice"
+     is_composition = True
+     sub_configs = {
+         "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
+         "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
+         "decoder_config": Qwen2Config,
+         "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
+     }
+     # keys_to_ignore_at_inference = ["past_key_values"]
+     # Default tensor parallel plan for base model `Qwen2`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+
+     def __init__(
+         self,
+         acoustic_tokenizer_config=None,
+         semantic_tokenizer_config=None,
+         decoder_config=None,
+         diffusion_head_config=None,
+         **kwargs
+     ):
+
+         # kwargs["_attn_implementation"] = "flash_attention_2"
+         kwargs["_attn_implementation_autoset"] = False
+
+         if acoustic_tokenizer_config is None:
+             self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
+         elif isinstance(acoustic_tokenizer_config, dict):
+             acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
+             self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
+         elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
+             # If an instance of the config class is provided
+             self.acoustic_tokenizer_config = acoustic_tokenizer_config
+
+         if semantic_tokenizer_config is None:
+             self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
+         elif isinstance(semantic_tokenizer_config, dict):
+             semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
+             self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
+         elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
+             # If an instance of the config class is provided
+             self.semantic_tokenizer_config = semantic_tokenizer_config
+
+         if decoder_config is None:
+             self.decoder_config = self.sub_configs["decoder_config"]()
+         elif isinstance(decoder_config, dict):
+             # If a dictionary is provided, instantiate the config class with it
+             # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
+             if decoder_config.get("model_type", '') == "qwen2":
+                 self.decoder_config = Qwen2Config(**decoder_config)
+             else:
+                 raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
+         elif isinstance(decoder_config, (Qwen2Config,)):
+             # If an instance of the config class is provided
+             self.decoder_config = decoder_config
+
+         if diffusion_head_config is None:
+             self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
+         elif isinstance(diffusion_head_config, dict):
+             diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
+             self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
+         elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
+             # If an instance of the config class is provided
+             self.diffusion_head_config = diffusion_head_config
+
+         # other parameters
+         self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
+         self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
+
+         super().__init__(**kwargs)
+
+ __all__ = [
+     "VibeVoiceAcousticTokenizerConfig",
+     "VibeVoiceSemanticTokenizerConfig",
+     "VibeVoiceDiffusionHeadConfig",
+     "VibeVoiceConfig"
+ ]
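
For reference, the composite config can also be built programmatically from the sub-config classes; any sub-config that is not passed falls back to the defaults defined above. A short sketch (not part of the commit, assumes the repo root is importable):

from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from modular.configuration_vibevoice import VibeVoiceConfig, VibeVoiceDiffusionHeadConfig

decoder = Qwen2Config(hidden_size=1536, intermediate_size=8960,
                      num_hidden_layers=28, num_attention_heads=12,
                      num_key_value_heads=2, vocab_size=151936)
diffusion_head = VibeVoiceDiffusionHeadConfig(hidden_size=1536, speech_vae_dim=64)

cfg = VibeVoiceConfig(decoder_config=decoder, diffusion_head_config=diffusion_head)
# The tokenizer sub-configs fall back to their class defaults, so here
# semantic_vae_dim is 64 (the class default), not the 128 used in the JSON configs.
print(cfg.acoustic_vae_dim, cfg.semantic_vae_dim)  # 64 64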
modular/modeling_vibevoice.py ADDED
@@ -0,0 +1,488 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union, Callable
+ from tqdm import tqdm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
+
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
+ from transformers import modeling_utils
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.utils import logging
+
+
+ from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
+ from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
+ from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
+
+ from .configuration_vibevoice import VibeVoiceConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
+     modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
+
+ @dataclass
+ class VibeVoiceCausalLMOutputWithPast(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     diffusion_loss: Optional[torch.FloatTensor] = None
+     speech_token_num: Optional[int] = None
+     logits: torch.FloatTensor = None
+     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ @dataclass
+ class VibeVoiceGenerationOutput(ModelOutput):
+     """
+     Output type for VibeVoice generation.
+
+     Args:
+         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             The generated sequences.
+         speech_outputs (`List[torch.FloatTensor]`, *optional*):
+             List of generated speech waveforms or latents for each speech segment.
+     """
+     sequences: torch.LongTensor = None
+     speech_outputs: Optional[List[torch.FloatTensor]] = None
+
+
+ class SpeechConnector(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, output_dim)
+         self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
+         self.fc2 = nn.Linear(output_dim, output_dim)
+
+     def forward(self, features, **kwargs):
+         x = self.fc1(features)
+         x = self.norm(x)
+         x = self.fc2(x)
+         return x
+
+
+ # @auto_docstring
+ class VibeVoicePreTrainedModel(PreTrainedModel):
+     config_class = VibeVoiceConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _skip_keys_device_placement = "past_key_values"
+     _supports_cache_class = True
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+
+     def _init_weights(self, module):
+         if isinstance(module, VibeVoiceDiffusionHead):
+             module.initialize_weights()
+             return
+
+         # Use the language model's initializer_range if available
+         if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
+             std = self.config.language_model_config.initializer_range
+         elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
+             std = self.config.decoder_config.initializer_range
+         else:
+             std = 0.02 # Default value
+
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.weight.data.fill_(1.0)
+             module.bias.data.zero_()
+
+ # @auto_docstring
+ class VibeVoiceModel(VibeVoicePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
+             if isinstance(config.torch_dtype, str):
+                 dtype = getattr(torch, config.torch_dtype)
+             else:
+                 dtype = config.torch_dtype
+         else:
+             dtype = torch.float32
+
+         # Initialize Qwen2 model for language modeling
+         lm_config = config.decoder_config
+         self.language_model = AutoModel.from_config(lm_config)
+
+         # Initialize speech components if needed
+         self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
+         self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
+
+         self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
+         self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
+
+         # Register scaling factors as buffers - use 1D tensors for FSDP compatibility
+         self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
+         self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
+
+         # Initialize prediction head for speech generation
+         self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
+
+         # Initialize noise scheduler
+         self.noise_scheduler = DPMSolverMultistepScheduler(
+             num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
+             beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
+             prediction_type=config.diffusion_head_config.prediction_type
+         )
+
+     def get_input_embeddings(self):
+         if hasattr(self.language_model, 'embed_tokens'):
+             # If the language model has an embed_tokens attribute, return it
+             return self.language_model.embed_tokens
+
+         for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
+             if attr.orig_name == 'embed_tokens.weight':
+                 return getattr(self.language_model, name)
+         assert False, 'should not arrive here'
+
+     def set_input_embeddings(self, value):
+         self.language_model.embed_tokens = value
+
+     def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
+         """Set the speech tokenizers used for encoding and decoding speech."""
+         self.acoustic_tokenizer = acoustic_tokenizer
+         self.semantic_tokenizer = semantic_tokenizer
+
+         # Reset the encoder to evaluation mode
+         if self.acoustic_tokenizer is not None:
+             self.acoustic_tokenizer.eval()
+
+         if self.semantic_tokenizer is not None:
+             self.semantic_tokenizer.eval()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Forward through language model
+         outputs = self.language_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         if not return_dict:
+             return outputs
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=outputs.last_hidden_state,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = VibeVoiceModel(config)
+         self.vocab_size = config.decoder_config.vocab_size
+         self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_decoder(self, decoder):
+         self.model.language_model = decoder
+
+     def get_decoder(self):
+         return self.model.language_model
+
+     def tie_weights(self):
+         """
+         Tie the weights between the input embeddings and the output embeddings.
+         """
+         if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
+             # The standard PreTrainedModel method will handle the tying.
+             # It typically does a simple parameter object assignment, which is
+             # CORRECT to do BEFORE FSDP wraps the model.
+             output_embeddings = self.get_output_embeddings()
+             input_embeddings = self.get_input_embeddings()
+             if hasattr(input_embeddings, 'weight'):
+                 output_embeddings.weight = input_embeddings.weight
+             else:
+                 # maybe returned input_embeddings a tensor directly
+                 output_embeddings.weight = input_embeddings
+
+             if getattr(output_embeddings, "bias", None) is not None:
+                 output_embeddings.bias.data = nn.functional.pad(
+                     output_embeddings.bias.data,
+                     (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
+                     "constant",
+                     0,
+                 )
+             print("✅ Tied input and output embeddings using standard assignment.")
+         else:
+             print("ℹ️ tie_word_embeddings is False, not tying weights.")
+
+     # Also, ensure set_output_embeddings is safe, though your implementation looks okay.
+     # The key is to avoid calling it after accelerator.prepare().
+     def set_output_embeddings(self, new_embeddings):
+         # Your current implementation using data.copy_ is good practice,
+         # but the best way is to not call this after prepare().
+         self.lm_head = new_embeddings
+
+     def forward_speech_features(
+         self,
+         speech_tensors=None,
+         speech_masks=None,
+         speech_type="audio",
+         return_unmask=False
+     ):
+         if speech_tensors is None:
+             # Use config to get vae_dim instead of non-existent self.args
+             vae_dim = self.config.acoustic_tokenizer_config.vae_dim
+             audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
+             connect_features = self.model.acoustic_connector(audio_features)
+             return audio_features, connect_features
+         else:
+             with torch.no_grad():
+                 if speech_type == "audio":
+                     with torch.no_grad():
+                         frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
+                     audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
+
+                 elif speech_type == "vae":
+                     # Use config to get vae_dim instead of non-existent self.args
+                     vae_dim = self.config.acoustic_tokenizer_config.vae_dim
+                     speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
+
+                     # gaussian sample from the speech_mode
+                     batch_size = speech_mode.size(0)
+                     value = self.model.acoustic_tokenizer.fix_std / 0.8
+                     std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
+                     std = std.view(-1, *[1] * (speech_mode.dim() - 1))
+                     audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
+                 else:
+                     raise NotImplementedError(f"Speech type {speech_type} not implemented")
+
+             if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
+                 scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
+                 bias_factor = -audio_tokens[speech_masks].flatten().mean()
+
+                 # Only use distributed operations if the process group is initialized
+                 if dist.is_available() and dist.is_initialized():
+                     dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
+                     dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
+                     world_size = dist.get_world_size()
+                     self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
+                     self.model.speech_bias_factor.copy_(bias_factor / world_size)
+                     print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
+                 else:
+                     # Single process case
+                     self.model.speech_scaling_factor.copy_(scaling_factor)
+                     self.model.speech_bias_factor.copy_(bias_factor)
+                     print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
+
+             audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
+
+             connect_features = self.model.acoustic_connector(audio_features)
+             if return_unmask:
+                 return audio_features, connect_features
+             return audio_features[speech_masks], connect_features[speech_masks]
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         # New arguments for speech processing and loss calculation
+         speech_tensors: Optional[torch.FloatTensor] = None,
+         speech_masks: Optional[torch.BoolTensor] = None,
+         speeches_loss_input: Optional[torch.FloatTensor] = None,
+         speech_semantic_tensors: Optional[torch.FloatTensor] = None,
+         acoustic_input_mask: Optional[torch.BoolTensor] = None,
+         acoustic_loss_mask: Optional[torch.BoolTensor] = None,
+         ddpm_batch_mul: int = 1,
+         **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
+     ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         x = self.get_input_embeddings()(input_ids)
+
+         semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
+         if speeches_loss_input is not None:
+             # only part audio need diffuse
+             speech_all_features, speech_all_connect_features = self.forward_speech_features(
+                 speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
+                 speech_masks=speech_masks,
+                 speech_type=kwargs.get("speech_type", "audio"),
+                 return_unmask=True
+             )
+             if speech_tensors is not None:
+                 if semantic_speech_all_connect_features is not None:
+                     x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
+                 else:
+                     x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
+             speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
+             speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
+         else:
+             speech_features, speech_connect_features = self.forward_speech_features(
+                 speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
+                 speech_masks=speech_masks,
+                 speech_type=kwargs.get("speech_type", "audio"),
+             )
+             if speech_tensors is not None:
+                 x[acoustic_input_mask] = speech_connect_features
+
+         outputs = self.model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=x,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=False,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         logits = self.lm_head(hidden_states)
+         # logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # The custom CE loss with masking is calculated in the training script.
+             # We leave the standard loss calculation here as None.
+             pass
+
+         # --- Diffusion Loss Calculation ---
+         diffusion_loss = None
+         # This block is executed only if we are in a context that involves speech.
+         if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
+             condition_features = hidden_states[acoustic_loss_mask]
+
+             speech_len, latent_size = speech_features.shape
+
+             noise = torch.randn(
+                 (speech_len * ddpm_batch_mul, latent_size),
+                 device=hidden_states.device,
+                 dtype=hidden_states.dtype
+             )
+
+             timesteps = torch.multinomial(
+                 torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
+                 speech_len * ddpm_batch_mul,
+                 replacement=True,
+             ).to(hidden_states.device)
+
+             speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
+             condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
+
+             noisy_speech_features = self.model.noise_scheduler.add_noise(
+                 speech_features_repeated, noise, timesteps
+             )
+
+             model_output = self.model.prediction_head(
+                 noisy_speech_features,
+                 timesteps.type_as(x),
+                 condition_features_repeated
+             )
+
+             prediction_type = self.config.diffusion_head_config.prediction_type
+             if prediction_type == "epsilon":
+                 target_for_loss = noise
+             elif prediction_type == "v_prediction":
+                 target_for_loss = self.model.noise_scheduler.get_velocity(
+                     speech_features_repeated, noise, timesteps
+                 )
+             else:
+                 raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
+
+             diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
+             if latent_size > 0 and ddpm_batch_mul > 0:
+                 diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
+             else:
+                 diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
+
+         else:
+             # Dummy loss for DDP to work when there are no speech samples in a batch,
+             # but we are in a speech context.
+             diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
+             diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
+             diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
+         # --- End Diffusion Loss Calculation ---
+
+         if not return_dict:
+             output = (logits, speech_len) + outputs.to_tuple()[1:]
+             return (loss, diffusion_loss) + output
+
+         return VibeVoiceCausalLMOutputWithPast(
+             loss=loss,
+             diffusion_loss=diffusion_loss,
+             speech_token_num=speech_len if speech_tensors is not None else 0,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
+ AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
+
+ __all__ = [
+     "VibeVoiceModel",
+     "VibeVoicePreTrainedModel",
+     "VibeVoiceForConditionalGeneration",
+     "VibeVoiceCausalLMOutputWithPast",
+     "VibeVoiceGenerationOutput",
+ ]
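
Because the closing lines register the classes with the Auto factories, the training model can be materialized straight from one of the JSON configs above. A hedged sketch (not part of the commit): it assumes the companion modules referenced by the imports (`modular_vibevoice_tokenizer`, `modular_vibevoice_diffusion_head`, the external `vibevoice.schedule.dpm_solver`) resolve and that the tokenizer/diffusion-head model classes are registered with AutoModel in their own modules.

import torch
from modular.configuration_vibevoice import VibeVoiceConfig
from modular.modeling_vibevoice import VibeVoiceForConditionalGeneration

config = VibeVoiceConfig.from_json_file("configs/qwen2.5_1.5b_64k.json")
model = VibeVoiceForConditionalGeneration(config)

# tie_weights() follows decoder_config.tie_word_embeddings (true for the 1.5B config),
# and the speech scaling/bias buffers start out as NaN until the first speech batch
# triggers the statistics computation in forward_speech_features.
print(torch.isnan(model.model.speech_scaling_factor))  # tensor(True)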
modular/modeling_vibevoice_inference.py ADDED
@@ -0,0 +1,715 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple, Union, Callable
3
+ from tqdm import tqdm
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
8
+
9
+ from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
10
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
11
+ from transformers import modeling_utils
12
+ from transformers.modeling_utils import PreTrainedModel
13
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
14
+ from transformers.utils import logging
15
+
16
+
17
+ # from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
18
+ from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
19
+ from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
20
+ from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
21
+
22
+ from .configuration_vibevoice import VibeVoiceConfig
23
+
24
+ from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
25
+
26
+ from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
27
+ from .streamer import AudioStreamer, AsyncAudioStreamer
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
32
+ modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
33
+
34
+ @dataclass
35
+ class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
36
+ logits: Optional[torch.FloatTensor] = None
37
+
38
+ @dataclass
39
+ class VibeVoiceGenerationOutput(ModelOutput):
40
+ """
41
+ Output type for VibeVoice generation.
42
+
43
+ Args:
44
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
45
+ The generated sequences.
46
+ speech_outputs (`List[torch.FloatTensor]`, *optional*):
47
+ List of generated speech waveforms or latents for each speech segment.
48
+ """
49
+ sequences: torch.LongTensor = None
50
+ speech_outputs: Optional[List[torch.FloatTensor]] = None
51
+ reach_max_step_sample: Optional[torch.BoolTensor] = None
52
+
53
+ class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
54
+ """Constrains token generation to only valid tokens during speech generation."""
55
+
56
+ def __init__(self, valid_token_ids: List[int], device: torch.device = None):
57
+ self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
58
+
59
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
+ # Create a mask for valid tokens
61
+ mask = torch.full_like(scores, float('-inf'))
62
+ mask[:, self.valid_token_ids] = 0
63
+
64
+ # Apply mask to scores
65
+ scores = scores + mask
66
+ return scores
67
+
68
+ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
69
+ _tied_weights_keys = ["lm_head.weight"]
70
+ _tp_plan = {"lm_head": "colwise_rep"}
71
+
72
+ def __init__(self, config):
73
+ super().__init__(config)
74
+
75
+ # Initialize the base model
76
+ self.model = VibeVoiceModel(config)
77
+
78
+ # LM head for text generation
79
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
80
+
81
+ # inference configuration
82
+ self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
83
+
84
+ # Initialize weights and apply final processing
85
+ self.post_init()
86
+
87
+ @property
88
+ def noise_scheduler(self):
89
+ return self.model.noise_scheduler
90
+
91
+ @property
92
+ def prediction_head(self):
93
+ return self.model.prediction_head
94
+
95
+ @property
96
+ def speech_scaling_factor(self):
97
+ return self.model.speech_scaling_factor
98
+
99
+ @property
100
+ def speech_bias_factor(self):
101
+ return self.model.speech_bias_factor
102
+
103
+ @property
104
+ def acoustic_tokenizer(self):
105
+ return self.model.acoustic_tokenizer
106
+
107
+ @property
108
+ def semantic_tokenizer(self):
109
+ return self.model.semantic_tokenizer
110
+
111
+ @property
112
+ def acoustic_connector(self):
113
+ return self.model.acoustic_connector
114
+
115
+ @property
116
+ def semantic_connector(self):
117
+ return self.model.semantic_connector
118
+
119
+ def tie_weights(self):
120
+ """
121
+ Tie the weights between the input embeddings and the output embeddings.
122
+ """
123
+ # Tie lm_head.weight to language_model.embed_tokens.weight
124
+ if not getattr(self.config, 'tie_word_embeddings', False):
125
+ return
126
+
127
+ if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
128
+ self.lm_head.weight = self.model.language_model.embed_tokens.weight
129
+
130
+ def get_input_embeddings(self):
131
+ return self.model.get_input_embeddings()
132
+
133
+ def set_input_embeddings(self, value):
134
+ self.model.set_input_embeddings(value)
135
+
136
+ def get_output_embeddings(self):
137
+ return self.lm_head
138
+
139
+ def set_output_embeddings(self, new_embeddings):
140
+ self.lm_head = new_embeddings
141
+
142
+ def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
143
+ """Set the speech tokenizers used for encoding and decoding speech."""
144
+ self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
145
+
146
+ def set_ddpm_inference_steps(self, num_steps=None):
147
+ self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
148
+
149
+ def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
150
+ """Process speech inputs through tokenizers and connectors."""
151
+ with torch.no_grad():
152
+ if speech_type == "audio":
153
+ # Encode audio to acoustic latents
154
+ encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
155
+ acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
156
+
157
+ # Apply scaling and bias
158
+ acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
159
+
160
+ # Connect to language model space
161
+ acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
162
+
163
+ return acoustic_features, acoustic_connected
164
+ elif speech_type == "pt":
165
+ encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
166
+ acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
167
+
168
+ # Apply scaling and bias
169
+ acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
170
+
171
+ # Connect to language model space
172
+ acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
173
+
174
+ return acoustic_features, acoustic_connected
175
+ else:
176
+ raise NotImplementedError(f"Speech type {speech_type} not implemented")
177
+
178
+ # @can_return_tuple
179
+ def forward(
180
+ self,
181
+ input_ids: torch.LongTensor = None,
182
+ attention_mask: Optional[torch.Tensor] = None,
183
+ position_ids: Optional[torch.LongTensor] = None,
184
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
185
+ inputs_embeds: Optional[torch.FloatTensor] = None,
186
+ labels: Optional[torch.LongTensor] = None,
187
+ use_cache: Optional[bool] = None,
188
+ output_attentions: Optional[bool] = None,
189
+ output_hidden_states: Optional[bool] = None,
190
+ return_dict: Optional[bool] = None,
191
+ cache_position: Optional[torch.LongTensor] = None,
192
+ speech_tensors: Optional[torch.FloatTensor] = None,
193
+ speech_masks: Optional[torch.BoolTensor] = None,
194
+ speech_input_mask: Optional[torch.BoolTensor] = None,
195
+ logits_to_keep: Union[int, slice] = 0,
196
+ **kwargs,
197
+ ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
198
+ """
199
+ Args:
200
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
201
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
202
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
203
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
204
+ speech_tensors (`torch.FloatTensor`, *optional*):
205
+ Input speech waveforms for voice cloning or speech understanding.
206
+ speech_masks (`torch.BoolTensor`, *optional*):
207
+ Masks indicating valid speech frames.
208
+ speech_input_mask (`torch.BoolTensor`, *optional*):
209
+ Positions in the input sequence where speech embeddings should be inserted.
210
+
211
+ Returns:
212
+ `VibeVoiceCausalLMOutputWithPast` or tuple
213
+ """
214
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
215
+
216
+ # Get embeddings
217
+ if inputs_embeds is None:
218
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
219
+
220
+ # Process speech inputs if provided
221
+ if speech_tensors is not None and speech_masks is not None:
222
+ acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
223
+ if speech_input_mask is not None:
224
+ inputs_embeds[speech_input_mask] = speech_embeds
225
+
226
+ outputs = self.model(
227
+ inputs_embeds=inputs_embeds,
228
+ attention_mask=attention_mask,
229
+ position_ids=position_ids,
230
+ past_key_values=past_key_values,
231
+ use_cache=use_cache,
232
+ output_attentions=output_attentions,
233
+ output_hidden_states=output_hidden_states,
234
+ return_dict=return_dict,
235
+ cache_position=cache_position,
236
+ **kwargs,
237
+ )
238
+
239
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
240
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
241
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
242
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
243
+
244
+ if labels is not None:
245
+ raise NotImplementedError("Loss computation is not implemented in this version.")
246
+
247
+ return VibeVoiceCausalLMOutputWithPast(
248
+ logits=logits,
249
+ past_key_values=outputs.past_key_values,
250
+ last_hidden_state=hidden_states,
251
+ attentions=outputs.attentions,
252
+ )
253
+
254
+ def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
255
+ if generation_config is None:
256
+ generation_config = GenerationConfig(
257
+ bos_token_id=tokenizer.bos_token_id,
258
+ eos_token_id=tokenizer.eos_token_id,
259
+ pad_token_id = tokenizer.pad_token_id
260
+ )
261
+ else:
262
+ generation_config = GenerationConfig(
263
+ **generation_config,
264
+ bos_token_id=tokenizer.bos_token_id,
265
+ eos_token_id=tokenizer.eos_token_id,
266
+ pad_token_id = tokenizer.pad_token_id
267
+ )
268
+
269
+ generation_config, model_kwargs = self._prepare_generation_config(
270
+ generation_config,
271
+ True,
272
+ speech_start_id=tokenizer.speech_start_id,
273
+ speech_end_id=tokenizer.speech_end_id,
274
+ speech_diffusion_id=tokenizer.speech_diffusion_id,
275
+ **kwargs
276
+ )
277
+ generation_config.speech_start_id = tokenizer.speech_start_id
278
+ generation_config.speech_end_id = tokenizer.speech_end_id
279
+ generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
280
+
281
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
282
+ batch_size = inputs_tensor.shape[0]
283
+ device = self.device
284
+
285
+ self._prepare_special_tokens(generation_config, True, device=device)
286
+ generation_config.use_cache = True
287
+ model_kwargs["use_cache"] = generation_config.use_cache
288
+ input_ids = inputs_tensor.to(self.device)
289
+
290
+ input_ids_length = input_ids.shape[1]
291
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
292
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
293
+ generation_config = self._prepare_generated_length(
294
+ generation_config=generation_config,
295
+ has_default_max_length=has_default_max_length,
296
+ has_default_min_length=has_default_min_length,
297
+ model_input_name=model_input_name,
298
+ inputs_tensor=inputs_tensor,
299
+ input_ids_length=input_ids_length,
300
+ )
301
+
302
+ max_cache_length = generation_config.max_length - 1
303
+ self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
304
+ model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
305
+ for k, v in model_kwargs.items():
306
+ if isinstance(v, torch.Tensor):
307
+ model_kwargs[k] = v.to(device=device)
308
+
309
+ if return_processors:
310
+ logits_processor = self._get_logits_processor(
311
+ generation_config=generation_config,
312
+ input_ids_seq_length=input_ids_length,
313
+ encoder_input_ids=inputs_tensor,
314
+ prefix_allowed_tokens_fn=None,
315
+ logits_processor=LogitsProcessorList(),
316
+ device=inputs_tensor.device,
317
+ model_kwargs=model_kwargs,
318
+ )
319
+
320
+ stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
321
+
322
+ return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
323
+ else:
324
+ return generation_config, model_kwargs, input_ids
325
+
326
+ @torch.no_grad()
327
+ def generate(
328
+ self,
329
+ inputs: Optional[torch.Tensor] = None,
330
+ generation_config: Optional[GenerationConfig] = None,
331
+ logits_processor: Optional[LogitsProcessorList] = None,
332
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
333
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
334
+ synced_gpus: Optional[bool] = None,
335
+ assistant_model: Optional["PreTrainedModel"] = None,
336
+ audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
337
+ negative_prompt_ids: Optional[torch.Tensor] = None,
338
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
339
+ speech_tensors: Optional[torch.FloatTensor] = None,
340
+ speech_masks: Optional[torch.BoolTensor] = None,
341
+ speech_input_mask: Optional[torch.BoolTensor] = None,
342
+ return_speech: bool = True,
343
+ cfg_scale: float = 1.0,
344
+ stop_check_fn: Optional[Callable[[], bool]] = None,
345
+ **kwargs,
346
+ ) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
347
+ """
348
+ Generates sequences of token ids and optionally speech outputs.
349
+
350
+ Args:
351
+ All standard generation arguments from GenerationMixin
352
+ negative_prompt_ids: Negative prompt for CFG in speech generation
353
+ negative_prompt_attention_mask: Attention mask for negative prompt
354
+ speech_tensors: Input speech for voice cloning
355
+ speech_masks: Masks for speech tensors
356
+ speech_input_mask: Positions to insert speech embeddings
357
+ return_speech: Whether to decode and return speech outputs
358
+ cfg_scale: CFG scale for speech generation
359
+ stop_check_fn: Optional callable that returns True if generation should stop
360
+
361
+ Returns:
362
+ Generated token sequences and optionally speech outputs
363
+ """
364
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
365
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
366
+ parsed_scripts = kwargs.pop("parsed_scripts", None)
367
+ all_speakers_list = kwargs.pop("all_speakers_list", None)
368
+ max_length_times = kwargs.pop("max_length_times", 2)
369
+
370
+ if kwargs.get('max_new_tokens', None) is None:
371
+ kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
372
+
373
+ generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
374
+ generation_config, inputs, tokenizer, return_processors=True, **kwargs
375
+ )
376
+
377
+ negative_kwargs = {
378
+ 'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
379
+ 'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
380
+ 'max_new_tokens': kwargs.get('max_new_tokens', 100)
381
+ }
382
+ negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
383
+ None, None, tokenizer, return_processors=False, **negative_kwargs
384
+ )
385
+
386
+ acoustic_cache = VibeVoiceTokenizerStreamingCache()
387
+ semantic_cache = VibeVoiceTokenizerStreamingCache()
388
+
389
+ batch_size = input_ids.shape[0]
390
+ device = input_ids.device
391
+ finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
392
+ correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
393
+ is_prefill = True
394
+ inputs_embeds = None
395
+ verbose = kwargs.get("verbose", False)
396
+
397
+ # Initialize audio chunks storage for each sample
398
+ audio_chunks = [[] for _ in range(batch_size)]
399
+
400
+ initial_length = input_ids.shape[-1]
401
+ initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
402
+
403
+ # Define all valid tokens that can be generated
404
+ valid_tokens = [
405
+ generation_config.speech_start_id,
406
+ generation_config.speech_end_id,
407
+ generation_config.speech_diffusion_id,
408
+ generation_config.eos_token_id
409
+ ]
410
+ # Add bos_token_id if it exists
411
+ if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
412
+ valid_tokens.append(generation_config.bos_token_id)
413
+
414
+ # Add custom processor to constrain token generation
415
+ token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
416
+ if logits_processor is None:
417
+ logits_processor = LogitsProcessorList()
418
+ logits_processor.append(token_constraint_processor)
419
+
420
+ max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
421
+ max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
422
+ reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
423
+
424
+ # Create progress iterator if verbose
425
+ if kwargs.get("show_progress_bar", True):
426
+ progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
427
+ else:
428
+ progress_bar = range(max_steps)
429
+
430
+ for step in progress_bar:
431
+ # Check for external stop signal
432
+ if stop_check_fn is not None and stop_check_fn():
433
+ if verbose:
434
+ print(f"Generation stopped externally at step {step + 1}")
435
+ # End the audio streamer if it exists
436
+ if audio_streamer is not None:
437
+ audio_streamer.end()
438
+ break
439
+
440
+ # Check if audio_streamer has been ended (stopped externally)
441
+ if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
442
+ if any(audio_streamer.finished_flags):
443
+ if verbose:
444
+ print(f"Audio generation stopped externally at step {step + 1}")
445
+ break
446
+
447
+ if finished_tags.all():
448
+ if hasattr(progress_bar, 'set_description'):
449
+ progress_bar.set_description("Generation complete")
450
+ break
451
+
452
+ if input_ids.shape[-1] >= generation_config.max_length:
453
+ print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
454
+ reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
455
+ if reached_samples.numel() > 0:
456
+ reach_max_step_sample[reached_samples] = True
457
+ break
458
+
459
+ # Update progress bar description with active samples
460
+ if hasattr(progress_bar, 'set_description'):
461
+ active_samples = (~finished_tags).sum().item()
462
+ progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
463
+
464
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
465
+ if is_prefill:
466
+ # we process the speech inputs only during the first generation step
467
+ prefill_inputs = {
468
+ "speech_tensors": speech_tensors.to(device=device),
469
+ "speech_masks": speech_masks.to(device),
470
+ "speech_input_mask": speech_input_mask.to(device),
471
+ }
472
+ is_prefill = False
473
+ else:
474
+ _ = model_inputs.pop('inputs_embeds', None)
475
+ prefill_inputs = {'inputs_embeds': inputs_embeds}
476
+
477
+ # Forward pass through the model
478
+ outputs = self(
479
+ **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
480
+ )
481
+ model_kwargs = self._update_model_kwargs_for_generation(
482
+ outputs, model_kwargs, is_encoder_decoder=False,
483
+ )
484
+
485
+ # Get logits and apply logits processor
486
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
487
+ # next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
488
+ next_token_scores = logits_processor(input_ids, next_token_logits)
489
+
490
+ # token selection
491
+ if generation_config.do_sample:
492
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
493
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
494
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
495
+ else:
496
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
497
+
498
+ next_tokens[finished_tags] = generation_config.eos_token_id
499
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
500
+
501
+ if not kwargs.get('refresh_negative', True):
502
+ negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
503
+ # Forward negative pass through the model
504
+ if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
505
+ negative_model_inputs['inputs_embeds'] = inputs_embeds
506
+ negative_model_inputs['input_ids'] = None
507
+
508
+ negative_outputs = self(
509
+ **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
510
+ )
511
+ negative_model_kwargs = self._update_model_kwargs_for_generation(
512
+ negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
513
+ )
514
+ negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
515
+
516
+ # reached end of generation
517
+ if (next_tokens == generation_config.eos_token_id).any():
518
+ eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
519
+ # Only print for samples that are newly finished (not already marked as finished)
520
+ new_eos_indices = eos_indices[~finished_tags[eos_indices]]
521
+ if new_eos_indices.numel() > 0:
522
+ finished_tags[new_eos_indices] = True
523
+ if verbose:
524
+ print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
525
+ if audio_streamer is not None:
526
+ audio_streamer.end(new_eos_indices)
527
+
528
+ # Check if any sample reached its maximum generation length
529
+ max_length_reached = step >= max_step_per_sample
530
+ new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
531
+ if new_max_length_indices.numel() > 0:
532
+ finished_tags[new_max_length_indices] = True
533
+ reach_max_step_sample[new_max_length_indices] = True
534
+ if verbose:
535
+ print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
536
+ if audio_streamer is not None:
537
+ audio_streamer.end(new_max_length_indices)
538
+
539
+ # speech_end
540
+ diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
541
+ if diffusion_end_indices.numel() > 0:
542
+ # Clear tokenizer caches for samples that reached speech end
543
+ acoustic_cache.set_to_zero(diffusion_end_indices)
544
+ semantic_cache.set_to_zero(diffusion_end_indices)
545
+
546
+ # speech_begin
547
+ diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
548
+ if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
549
+ # update attention mask
550
+ for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
551
+ negative_model_kwargs['attention_mask'][sample_idx, :] = 0
552
+ negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
553
+ # update past key values
554
+ for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
555
+ negative_model_kwargs['past_key_values'].value_cache)):
556
+ # Process each sample that just emitted a speech-start token
557
+ for sample_idx in diffusion_start_indices.tolist():
558
+ # Shift cache for this sample
559
+ k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
560
+ v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
561
+ # update negative_input_ids
562
+ for sample_idx in diffusion_start_indices.tolist():
563
+ negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
564
+
565
+ # Prepare inputs_embeds for next iteration
566
+ # Initialize with default embeddings for all tokens
567
+ next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
568
+
569
+ # forward diffusion
570
+ # Diffusion indices are those that are not finished and not special tokens
571
+ diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
572
+
573
+ if diffusion_indices.numel() > 0:
574
+ if kwargs.get('refresh_negative', True):
575
+ negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
576
+ # Forward negative pass through the model
577
+ if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
578
+ negative_model_inputs['inputs_embeds'] = inputs_embeds
579
+ negative_model_inputs['input_ids'] = None
580
+
581
+ negative_outputs = self(
582
+ **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
583
+ )
584
+ negative_model_kwargs = self._update_model_kwargs_for_generation(
585
+ negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
586
+ )
587
+ negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
588
+ # correct the non-diffusion indices
589
+ # We run the negative forward pass for every sample, even those not currently
590
+ # in diffusion mode, so the negative cache stays length-consistent across the batch;
591
+ # the KV cache of the non-diffusion samples therefore has to be corrected below.
592
+ non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
593
+ if non_diffusion_mask.any():
594
+ non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
595
+ start_indices = correct_cnt[non_diffusion_indices]
596
+
597
+ # 1. Update attention_mask - need to handle each sample separately
598
+ seq_len = negative_model_kwargs['attention_mask'].shape[1]
599
+ for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
600
+ # Shift the attention mask for this sample
601
+ if start_idx + 1 < seq_len - 1:
602
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
603
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
604
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
605
+
606
+ # 2. Update past_key_values
607
+ for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
608
+ negative_model_kwargs['past_key_values'].value_cache)):
609
+ # Process each non-diffusion sample
610
+ for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
611
+ if start_idx + 1 < k_cache.shape[2] - 1:
612
+ # Shift cache for this sample
613
+ k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
614
+ v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
615
+
616
+ # 3. Update negative_input_ids
617
+ for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
618
+ if start_idx + 1 < negative_input_ids.shape[1] - 1:
619
+ negative_input_ids[sample_idx, start_idx+1:] = \
620
+ negative_input_ids[sample_idx, start_idx:-1].clone()
621
+
622
+ correct_cnt[non_diffusion_indices] += 1
623
+
624
+ positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
625
+ negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
626
+
627
+ speech_latent = self.sample_speech_tokens(
628
+ positive_condition,
629
+ negative_condition,
630
+ cfg_scale=cfg_scale,
631
+ ).unsqueeze(1)
632
+
633
+ # Decode acoustic latent to audio using acoustic streaming cache
634
+ scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
635
+ audio_chunk = self.model.acoustic_tokenizer.decode(
636
+ scaled_latent.to(self.model.acoustic_tokenizer.device),
637
+ cache=acoustic_cache, # Use acoustic-specific cache
638
+ sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
639
+ use_cache=True,
640
+ debug=False
641
+ )
642
+
643
+ # Store audio chunks for each sample
644
+ for i, sample_idx in enumerate(diffusion_indices):
645
+ idx = sample_idx.item()
646
+ # Only append audio chunk if the sample is not finished
647
+ if not finished_tags[idx]:
648
+ audio_chunks[idx].append(audio_chunk[i])
649
+
650
+ # Add streaming support here
651
+ if audio_streamer is not None:
652
+ # Stream the audio chunks immediately
653
+ audio_streamer.put(audio_chunk, diffusion_indices)
654
+
655
+ # Encode audio to semantic features using semantic streaming cache
656
+ semantic_features = self.model.semantic_tokenizer.encode(
657
+ audio_chunk,
658
+ cache=semantic_cache, # Use semantic-specific cache
659
+ sample_indices=diffusion_indices,
660
+ use_cache=True,
661
+ debug=False
662
+ ).mean # the semantic tokenizer has no VAE, so use the deterministic mean directly.
663
+
664
+ # Combine acoustic and semantic features for next input
665
+ acoustic_embed = self.model.acoustic_connector(speech_latent)
666
+ semantic_embed = self.model.semantic_connector(semantic_features)
667
+ diffusion_embeds = acoustic_embed + semantic_embed
668
+
669
+ # Update embeddings for diffusion indices
670
+ next_inputs_embeds[diffusion_indices] = diffusion_embeds
671
+
672
+ # Set inputs_embeds for next iteration
673
+ inputs_embeds = next_inputs_embeds
674
+
675
+ if audio_streamer is not None:
676
+ audio_streamer.end()
677
+
678
+ # Concatenate audio chunks for each sample
679
+ final_audio_outputs = []
680
+ for sample_chunks in audio_chunks:
681
+ if sample_chunks:
682
+ # Concatenate all chunks along the time dimension (assumed to be the last dimension)
683
+ concatenated_audio = torch.cat(sample_chunks, dim=-1)
684
+ final_audio_outputs.append(concatenated_audio)
685
+ else:
686
+ # If no audio was generated for this sample, append None
687
+ final_audio_outputs.append(None)
688
+
689
+ return VibeVoiceGenerationOutput(
690
+ sequences=input_ids,
691
+ speech_outputs=final_audio_outputs if return_speech else None,
692
+ reach_max_step_sample=reach_max_step_sample,
693
+ )
694
+
695
+ @torch.no_grad()
696
+ def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
697
+ self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
698
+ condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
699
+ speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
700
+ for t in self.model.noise_scheduler.timesteps:
701
+ half = speech[: len(speech) // 2]
702
+ combined = torch.cat([half, half], dim=0)
703
+ eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
704
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
705
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
706
+ eps = torch.cat([half_eps, half_eps], dim=0)
707
+ speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
708
+ return speech[: len(speech) // 2]
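+ # Note on the loop above (descriptive, illustrative): classifier-free guidance duplicates
+ # the latent, predicts eps for the positive and negative conditions in one batch, then
+ # blends them as eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond), so cfg_scale = 1.0
+ # reduces to the purely conditional prediction.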
709
+
710
+
711
+ AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
712
+
713
+ __all__ = [
714
+ "VibeVoiceForConditionalGenerationInference",
715
+ ]
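For orientation, the sketch below shows one hypothetical way to drive the `generate` override above. Only the keyword names visible in the signature (`speech_tensors`, `speech_masks`, `speech_input_mask`, `cfg_scale`, `return_speech`, plus `input_ids`, `attention_mask` and `tokenizer` routed through `**kwargs`) come from this file; the pre-built `model`, `tok` and `batch` objects and the cfg value are assumptions, not part of this commit.

# `model`: a loaded VibeVoiceForConditionalGenerationInference (assumed)
# `tok`:   the matching VibeVoiceTextTokenizerFast (assumed)
# `batch`: a dict produced by a matching processor, not shown in this section (assumed)
out = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    tokenizer=tok,                               # supplies speech_start_id and stopping criteria
    speech_tensors=batch.get("speech_tensors"),  # reference audio for voice cloning
    speech_masks=batch.get("speech_masks"),
    speech_input_mask=batch.get("speech_input_mask"),
    cfg_scale=1.3,                               # classifier-free guidance strength
    return_speech=True,
)
tokens = out.sequences                           # generated token ids
audio = out.speech_outputs[0]                    # concatenated audio chunks, or None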
modular/modular_vibevoice_diffusion_head.py ADDED
@@ -0,0 +1,287 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers.models.auto import AutoModel
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ # from transformers.modeling_layers import GradientCheckpointingLayer
11
+ from transformers.activations import ACT2FN
12
+ from transformers.utils import logging
13
+
14
+ from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class RMSNorm(nn.Module):
21
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
22
+ super().__init__()
23
+ self.dim = dim
24
+ self.eps = eps
25
+ self.elementwise_affine = elementwise_affine
26
+ if self.elementwise_affine:
27
+ self.weight = nn.Parameter(torch.ones(dim))
28
+ else:
29
+ self.register_parameter('weight', None)
30
+
31
+ def _norm(self, x):
32
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
33
+
34
+ def forward(self, x):
35
+ output = self._norm(x.float()).type_as(x)
36
+ if self.weight is not None:
37
+ output = output * self.weight
38
+ return output
39
+
40
+ def extra_repr(self) -> str:
41
+ return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
42
+
43
+ def modulate(x, shift, scale):
44
+ """Apply modulation to input tensor."""
45
+ return x * (1 + scale) + shift
46
+
47
+
48
+ class TimestepEmbedder(nn.Module):
49
+ """
50
+ Embeds scalar timesteps into vector representations.
51
+
52
+ Args:
53
+ hidden_size (`int`): Size of the output embedding
54
+ frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
55
+ """
56
+ def __init__(self, hidden_size, frequency_embedding_size=256):
57
+ super().__init__()
58
+ self.mlp = nn.Sequential(
59
+ nn.Linear(frequency_embedding_size, hidden_size, bias=False),
60
+ # nn.SiLU(),
61
+ ACT2FN['silu'],
62
+ nn.Linear(hidden_size, hidden_size, bias=False),
63
+ )
64
+ self.frequency_embedding_size = frequency_embedding_size
65
+
66
+ @staticmethod
67
+ def timestep_embedding(t, dim, max_period=10000):
68
+ """
69
+ Create sinusoidal timestep embeddings.
70
+
71
+ Args:
72
+ t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
73
+ These may be fractional.
74
+ dim (`int`): The dimension of the output.
75
+ max_period (`int`, optional): Controls the minimum frequency of the embeddings.
76
+
77
+ Returns:
78
+ `torch.Tensor`: An [N, D] Tensor of positional embeddings.
79
+ """
80
+ half = dim // 2
81
+ freqs = torch.exp(
82
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
83
+ ).to(t.device)
84
+ args = t[:, None].float() * freqs[None]
85
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
86
+ if dim % 2:
87
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
88
+ return embedding.to(t.dtype)
89
+
90
+ def forward(self, t):
91
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
92
+ t_emb = self.mlp(t_freq)
93
+ return t_emb
94
+
95
+
96
+ class FeedForwardNetwork(nn.Module):
97
+ """
98
+ Standard feed-forward network with SwiGLU activation.
99
+
100
+ Args:
101
+ embed_dim (`int`): Input dimension
102
+ ffn_dim (`int`): Hidden dimension
103
+ """
104
+ def __init__(
105
+ self,
106
+ embed_dim,
107
+ ffn_dim,
108
+ ):
109
+ super().__init__()
110
+ self.embed_dim = embed_dim
111
+ self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
112
+ self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
113
+ self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
114
+ self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
115
+
116
+ def forward(self, x):
117
+ gate = self.gate_proj(x)
118
+ up = self.up_proj(x)
119
+
120
+ # SwiGLU activation
121
+ # gate = F.silu(gate)
122
+ gate = self.act_fn(gate)
123
+ return self.down_proj(gate * up)
124
+
125
+
126
+ class HeadLayer(nn.Module):
127
+ """
128
+ A layer in the diffusion head.
129
+
130
+ Args:
131
+ embed_dim (`int`): Input dimension
132
+ ffn_dim (`int`): Hidden dimension
133
+ cond_dim (`int`): Condition embedding dimension
134
+ norm_eps (`float`, optional): Epsilon for normalization
135
+ """
136
+ def __init__(
137
+ self,
138
+ embed_dim,
139
+ ffn_dim,
140
+ cond_dim,
141
+ norm_eps=1e-5,
142
+ ):
143
+ super().__init__()
144
+ self.embed_dim = embed_dim
145
+ self.cond_dim = cond_dim
146
+ self.ffn_dim = ffn_dim
147
+ self.ffn = FeedForwardNetwork(
148
+ self.embed_dim,
149
+ self.ffn_dim,
150
+ )
151
+ self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
152
+ self.adaLN_modulation = nn.Sequential(
153
+ # nn.SiLU(),
154
+ ACT2FN['silu'],
155
+ nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
156
+ )
157
+
158
+ def forward(self, x, c):
159
+ shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
160
+ x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
161
+ return x
162
+
163
+
164
+ class FinalLayer(nn.Module):
165
+ """
166
+ Final layer in the diffusion head.
167
+
168
+ Args:
169
+ hidden_size (`int`): Input dimension
170
+ output_size (`int`): Output dimension
171
+ cond_size (`int`): Condition embedding dimension
172
+ norm_eps (`float`, optional): Epsilon for normalization
173
+ """
174
+ def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
175
+ super().__init__()
176
+ self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
177
+ self.linear = nn.Linear(hidden_size, output_size, bias=False)
178
+ self.adaLN_modulation = nn.Sequential(
179
+ # nn.SiLU(),
180
+ ACT2FN['silu'],
181
+ nn.Linear(cond_size, 2 * hidden_size, bias=False)
182
+ )
183
+
184
+ def forward(self, x, c):
185
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
186
+ x = modulate(self.norm_final(x), shift, scale)
187
+ x = self.linear(x)
188
+ return x
189
+
190
+
191
+ class VibeVoiceDiffusionHead(PreTrainedModel):
192
+ """
193
+ Diffusion head model for vibevoice.
194
+
195
+ Args:
196
+ config (`VibeVoiceDiffusionHeadConfig`): Model configuration
197
+ latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
198
+ """
199
+ config_class = VibeVoiceDiffusionHeadConfig
200
+ supports_gradient_checkpointing = True
201
+ _supports_flash_attn_2 = True
202
+ _supports_sdpa = True
203
+
204
+ def __init__(
205
+ self,
206
+ config,
207
+ ):
208
+ super().__init__(config)
209
+ self.config = config
210
+ self.cond_dim = config.hidden_size
211
+ latent_size = config.latent_size
212
+
213
+ self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
214
+ self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
215
+ self.t_embedder = TimestepEmbedder(self.cond_dim)
216
+
217
+ ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
218
+
219
+ # Create the intermediate layers
220
+ self.layers = nn.ModuleList([
221
+ HeadLayer(
222
+ embed_dim=config.hidden_size,
223
+ ffn_dim=ffn_dim,
224
+ cond_dim=self.cond_dim,
225
+ norm_eps=config.rms_norm_eps
226
+ )
227
+ for _ in range(config.head_layers)
228
+ ])
229
+
230
+ # Final layer for output
231
+ self.final_layer = FinalLayer(
232
+ hidden_size=config.hidden_size,
233
+ output_size=latent_size,
234
+ cond_size=self.cond_dim,
235
+ norm_eps=config.rms_norm_eps
236
+ )
237
+
238
+ self.initialize_weights()
239
+
240
+ def initialize_weights(self):
241
+ """Initialize the weights of the model."""
242
+ # Initialize timestep embedder
243
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
244
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
245
+
246
+ # Zero-out adaLN modulation layers
247
+ for layer in self.layers:
248
+ nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
249
+
250
+ # Zero-out output layers
251
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
252
+ nn.init.constant_(self.final_layer.linear.weight, 0)
253
+
254
+ def forward(
255
+ self,
256
+ noisy_images,
257
+ timesteps,
258
+ condition,
259
+ ):
260
+ """
261
+ Forward pass of the prediction head.
262
+
263
+ Args:
264
+ noisy_images (`torch.Tensor`): Noisy images/latents to denoise
265
+ timesteps (`torch.Tensor`): Timesteps for diffusion
266
+ condition (`torch.Tensor`): Conditioning information
267
+
268
+ Returns:
269
+ `torch.Tensor`: The predicted noise/velocity
270
+ """
271
+ x = self.noisy_images_proj(noisy_images)
272
+ t = self.t_embedder(timesteps)
273
+ condition = self.cond_proj(condition)
274
+ c = condition + t
275
+
276
+ for layer in self.layers:
277
+ x = layer(x, c)
278
+
279
+ x = self.final_layer(x, c)
280
+ return x
281
+
282
+
283
+ AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
284
+
285
+ __all__ = [
286
+ "VibeVoiceDiffusionHead",
287
+ ]
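As a quick shape check for the head defined above, here is a minimal sketch. It assumes `VibeVoiceDiffusionHeadConfig` (imported from `.configuration_vibevoice`, which is not shown in this section) accepts these keyword arguments; the values mirror `diffusion_head_config` in `configs/qwen2.5_1.5b_64k.json`, and the import paths are assumptions.

import torch
from modular.configuration_vibevoice import VibeVoiceDiffusionHeadConfig     # path assumed
from modular.modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead  # path assumed

cfg = VibeVoiceDiffusionHeadConfig(
    hidden_size=1536, latent_size=64, head_layers=4, head_ffn_ratio=3.0, rms_norm_eps=1e-5,
)
head = VibeVoiceDiffusionHead(cfg)

noisy = torch.randn(8, 64)         # [batch, latent_size] noisy acoustic latents
t = torch.randint(0, 1000, (8,))   # diffusion timesteps
cond = torch.randn(8, 1536)        # [batch, hidden_size] conditioning hidden states
pred = head(noisy, t, cond)        # -> [8, 64], interpreted as a v-prediction per the config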
modular/modular_vibevoice_text_tokenizer.py ADDED
@@ -0,0 +1,214 @@
1
+ """Tokenization classes for vibevoice."""
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ from transformers.utils import logging
6
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
7
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+
12
+ class VibeVoiceTextTokenizer(Qwen2Tokenizer):
13
+ """
14
+ Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
15
+
16
+ Args:
17
+ vocab_file (`str`):
18
+ Path to the vocabulary file.
19
+ merges_file (`str`):
20
+ Path to the merges file.
21
+ errors (`str`, *optional*, defaults to `"replace"`):
22
+ Paradigm to follow when decoding bytes to UTF-8.
23
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
24
+ The unknown token.
25
+ bos_token (`str`, *optional*):
26
+ The beginning of sequence token. Not used for vibevoice.
27
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
28
+ The end of sequence token.
29
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
30
+ The token used for padding.
31
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
32
+ Whether or not to add special tokens when encoding.
33
+ """
34
+
35
+ model_input_names = ["input_ids", "attention_mask"]
36
+
37
+ def __init__(
38
+ self,
39
+ vocab_file,
40
+ merges_file,
41
+ errors="replace",
42
+ unk_token="<|endoftext|>",
43
+ bos_token=None,
44
+ eos_token="<|endoftext|>",
45
+ pad_token="<|endoftext|>",
46
+ add_prefix_space=False,
47
+ add_special_tokens=True,
48
+ **kwargs,
49
+ ):
50
+ super().__init__(
51
+ vocab_file=vocab_file,
52
+ merges_file=merges_file,
53
+ errors=errors,
54
+ unk_token=unk_token,
55
+ bos_token=bos_token,
56
+ eos_token=eos_token,
57
+ pad_token=pad_token,
58
+ add_prefix_space=add_prefix_space,
59
+ add_special_tokens=add_special_tokens,
60
+ **kwargs,
61
+ )
62
+
63
+ # Add VibeVoice-specific special tokens
64
+ self._add_vibevoice_special_tokens()
65
+
66
+ def _add_vibevoice_special_tokens(self):
67
+ """Add VibeVoice-specific special tokens."""
68
+ special_tokens = {
69
+ "additional_special_tokens": [
70
+ "<|vision_start|>", # Speech start (reusing vision tokens)
71
+ "<|vision_end|>", # Speech end
72
+ "<|vision_pad|>", # Speech diffusion pad
73
+ ]
74
+ }
75
+ num_added = self.add_special_tokens(special_tokens)
76
+
77
+ # Cache special token IDs
78
+ self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
79
+ self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
80
+ self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
81
+
82
+ self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
83
+
84
+ return num_added
85
+
86
+ @property
87
+ def eos_id(self) -> int:
88
+ """Id of the end of sequence token."""
89
+ return self._eos_id
90
+
91
+ @property
92
+ def speech_start_id(self) -> int:
93
+ """Id of the speech start token."""
94
+ return self._speech_start_id
95
+
96
+ @property
97
+ def speech_end_id(self) -> int:
98
+ """Id of the speech end token."""
99
+ return self._speech_end_id
100
+
101
+ @property
102
+ def speech_diffusion_id(self) -> int:
103
+ """Id of the speech diffusion token."""
104
+ return self._speech_diffusion_id
105
+
106
+ @property
107
+ def pad_id(self) -> int:
108
+ """Id used for padding (returns -100 for loss masking)."""
109
+ return -100
110
+
111
+
112
+ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
113
+ """
114
+ Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
115
+ Based on the Qwen2 tokenizer with additional special tokens for speech.
116
+
117
+ Args:
118
+ vocab_file (`str`, *optional*):
119
+ Path to the vocabulary file.
120
+ merges_file (`str`, *optional*):
121
+ Path to the merges file.
122
+ tokenizer_file (`str`, *optional*):
123
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
124
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
+ The unknown token.
126
+ bos_token (`str`, *optional*):
127
+ The beginning of sequence token. Not used for vibevoice.
128
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
129
+ The end of sequence token.
130
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
131
+ The token used for padding.
132
+ """
133
+
134
+ model_input_names = ["input_ids", "attention_mask"]
135
+
136
+ def __init__(
137
+ self,
138
+ vocab_file=None,
139
+ merges_file=None,
140
+ tokenizer_file=None,
141
+ unk_token="<|endoftext|>",
142
+ bos_token=None,
143
+ eos_token="<|endoftext|>",
144
+ pad_token="<|endoftext|>",
145
+ add_prefix_space=False,
146
+ **kwargs,
147
+ ):
148
+ super().__init__(
149
+ vocab_file=vocab_file,
150
+ merges_file=merges_file,
151
+ tokenizer_file=tokenizer_file,
152
+ unk_token=unk_token,
153
+ bos_token=bos_token,
154
+ eos_token=eos_token,
155
+ pad_token=pad_token,
156
+ add_prefix_space=add_prefix_space,
157
+ **kwargs,
158
+ )
159
+
160
+ # Add VibeVoice-specific special tokens
161
+ self._add_vibevoice_special_tokens()
162
+
163
+ def _add_vibevoice_special_tokens(self):
164
+ """Add VibeVoice-specific special tokens."""
165
+ special_tokens = {
166
+ "additional_special_tokens": [
167
+ "<|vision_start|>", # Speech start (reusing vision tokens)
168
+ "<|vision_end|>", # Speech end
169
+ "<|vision_pad|>", # Speech diffusion pad
170
+ ]
171
+ }
172
+ num_added = self.add_special_tokens(special_tokens)
173
+
174
+ # Cache special token IDs
175
+ self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
176
+ self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
177
+ self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
178
+
179
+ # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
180
+ self._eos_id = self.eos_token_id # qwen2 / qwen3
181
+ self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
182
+
183
+ return num_added
184
+
185
+ @property
186
+ def eos_id(self) -> int:
187
+ """Id of the end of sequence token."""
188
+ return self._eos_id
189
+
190
+ @property
191
+ def speech_start_id(self) -> int:
192
+ """Id of the speech start token."""
193
+ return self._speech_start_id
194
+
195
+ @property
196
+ def speech_end_id(self) -> int:
197
+ """Id of the speech end token."""
198
+ return self._speech_end_id
199
+
200
+ @property
201
+ def speech_diffusion_id(self) -> int:
202
+ """Id of the speech diffusion token."""
203
+ return self._speech_diffusion_id
204
+
205
+ @property
206
+ def pad_id(self) -> int:
207
+ """Id used for padding (returns -100 for loss masking)."""
208
+ return self._pad_id
209
+
210
+
211
+ __all__ = [
212
+ "VibeVoiceTextTokenizer",
213
+ "VibeVoiceTextTokenizerFast",
214
+ ]
modular/modular_vibevoice_tokenizer.py ADDED
@@ -0,0 +1,1195 @@
1
+ import math
2
+ import typing as tp
3
+ from functools import partial
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.models.auto import AutoModel
14
+
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.utils import logging
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.activations import ACT2FN
19
+
20
+ from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ import os
25
+ # Try to import APEX FusedRMSNorm
26
+ try:
27
+ from apex.normalization.fused_layer_norm import fused_rms_norm_affine
28
+ APEX_AVAILABLE = True
29
+ logger.info("APEX FusedRMSNorm is available and will be used for optimization")
30
+ if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
31
+ APEX_AVAILABLE = False
32
+ logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
33
+ except ImportError:
34
+ APEX_AVAILABLE = False
35
+ logger.warning("APEX FusedRMSNorm not available, using native implementation")
36
+ # APEX_AVAILABLE=False
37
+
38
+ # Normalization modules
39
+ class ConvLayerNorm(nn.LayerNorm):
40
+ """
41
+ Convolution-friendly LayerNorm that moves channels to last dimensions
42
+ before running the normalization and moves them back to original position right after.
43
+ """
44
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
45
+ super().__init__(normalized_shape, **kwargs)
46
+
47
+ def forward(self, x):
48
+ x = x.transpose(1, 2) # b ... t -> b t ...
49
+ x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
50
+ x = x.transpose(1, 2) # b t ... -> b ... t
51
+ return x
52
+
53
+ class RMSNorm(nn.Module):
54
+ def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.eps = eps
58
+ self.elementwise_affine = elementwise_affine
59
+ if self.elementwise_affine:
60
+ weight_shape = (dim,) if weight_shape is None else weight_shape
61
+ self.weight = nn.Parameter(torch.ones(weight_shape))
62
+ else:
63
+ self.register_parameter('weight', None)
64
+
65
+ def _norm(self, x):
66
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
67
+
68
+ def forward(self, x):
69
+ output = self._norm(x.float()).type_as(x)
70
+ if self.weight is not None:
71
+ output = output * self.weight
72
+ return output
73
+
74
+ def extra_repr(self) -> str:
75
+ return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
76
+
77
+ class ConvRMSNorm(RMSNorm):
78
+ def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
79
+ super().__init__(dim, eps, elementwise_affine, weight_shape)
80
+
81
+ def forward(self, x):
82
+ x = x.transpose(1, 2) # b ... t -> b t ...
83
+ if (not APEX_AVAILABLE) or (not self.elementwise_affine):
84
+ # Fallback to native implementation
85
+ output = self._norm(x.float()).type_as(x)
86
+ if self.weight is not None:
87
+ output = output * self.weight
88
+ else:
89
+ output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
90
+ output = output.transpose(1, 2) # b t ... -> b ... t
91
+ return output
92
+
93
+ # Convolutional layers and utilities
94
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
95
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
96
+
97
+
98
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
99
+ assert norm in CONV_NORMALIZATIONS
100
+ if norm == 'weight_norm':
101
+ return nn.utils.weight_norm(module)
102
+ elif norm == 'spectral_norm':
103
+ return nn.utils.spectral_norm(module)
104
+ else:
105
+ # We already check was in CONV_NORMALIZATION, so any other choice
106
+ # doesn't need reparametrization.
107
+ return module
108
+
109
+
110
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
111
+ """Return the proper normalization module. If causal is True, this will ensure the returned
112
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
113
+ """
114
+ assert norm in CONV_NORMALIZATIONS
115
+ if norm == 'layer_norm':
116
+ assert isinstance(module, nn.modules.conv._ConvNd)
117
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
118
+ elif norm == 'time_group_norm':
119
+ if causal:
120
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
121
+ assert isinstance(module, nn.modules.conv._ConvNd)
122
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
123
+ else:
124
+ return nn.Identity()
125
+
126
+
127
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
128
+ padding_total: int = 0) -> int:
129
+ """Calculate extra padding needed for convolution to have the same output length"""
130
+ length = x.shape[-1]
131
+ n_frames = (length - kernel_size + padding_total) / stride + 1
132
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
133
+ return ideal_length - length
134
+
135
+
136
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
137
+ """Pad 1D input with handling for small inputs in reflect mode"""
138
+ length = x.shape[-1]
139
+ padding_left, padding_right = paddings
140
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
141
+ if mode == 'reflect':
142
+ max_pad = max(padding_left, padding_right)
143
+ extra_pad = 0
144
+ if length <= max_pad:
145
+ extra_pad = max_pad - length + 1
146
+ x = F.pad(x, (0, extra_pad))
147
+ padded = F.pad(x, paddings, mode, value)
148
+ end = padded.shape[-1] - extra_pad
149
+ return padded[..., :end]
150
+ else:
151
+ return F.pad(x, paddings, mode, value)
152
+
153
+
154
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
155
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
156
+ padding_left, padding_right = paddings
157
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
158
+ assert (padding_left + padding_right) <= x.shape[-1]
159
+ end = x.shape[-1] - padding_right
160
+ return x[..., padding_left: end]
161
+
162
+
163
+ class NormConv1d(nn.Module):
164
+ """Wrapper around Conv1d and normalization applied to this conv"""
165
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
166
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
167
+ super().__init__()
168
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
169
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
170
+ self.norm_type = norm
171
+
172
+ def forward(self, x):
173
+ x = self.conv(x)
174
+ x = self.norm(x)
175
+ return x
176
+
177
+
178
+ class NormConvTranspose1d(nn.Module):
179
+ """Wrapper around ConvTranspose1d and normalization applied to this conv"""
180
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
185
+ self.norm_type = norm
186
+
187
+ def forward(self, x):
188
+ x = self.convtr(x)
189
+ x = self.norm(x)
190
+ return x
191
+
192
+
193
+ class VibeVoiceTokenizerStreamingCache:
194
+ """Cache for streaming convolution, similar to KV cache in attention"""
195
+ def __init__(self):
196
+ self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
197
+
198
+ def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
199
+ """Get cached states for given layer and sample indices"""
200
+ states = []
201
+ max_length = 0
202
+
203
+ # First pass: collect states and find max length
204
+ for idx in sample_indices.tolist():
205
+ key = (layer_id, idx)
206
+ if key not in self.cache:
207
+ return None # If any sample is missing, return None
208
+ state = self.cache[key]
209
+ states.append(state)
210
+ max_length = max(max_length, state.shape[-1])
211
+
212
+ # Second pass: pad states to max length if needed
213
+ if len(states) > 0 and states[0].dim() >= 2:
214
+ padded_states = []
215
+ for state in states:
216
+ if state.shape[-1] < max_length:
217
+ # Pad on the time dimension (last dimension)
218
+ pad_size = max_length - state.shape[-1]
219
+ # Pad with zeros on the LEFT to align the most recent samples
220
+ padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
221
+ padded_states.append(padded_state)
222
+ else:
223
+ padded_states.append(state)
224
+ return torch.stack(padded_states, dim=0)
225
+ else:
226
+ return torch.stack(states, dim=0)
227
+
228
+ def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
229
+ """Set cached states for given layer and sample indices"""
230
+ for i, idx in enumerate(sample_indices.tolist()):
231
+ key = (layer_id, idx)
232
+ self.cache[key] = states[i].detach()
233
+
234
+ def set_to_zero(self, sample_indices: torch.Tensor):
235
+ """Set all cached states to zero for given sample indices"""
236
+ for key in list(self.cache.keys()):
237
+ layer_id, sample_idx = key
238
+ if sample_idx in sample_indices.tolist():
239
+ # Create zero tensor with same shape and dtype as cached tensor
240
+ cached_tensor = self.cache[key]
241
+ self.cache[key] = torch.zeros_like(cached_tensor)
242
+
243
+ def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
244
+ """Clear cache for specific layer/samples or everything"""
245
+ if layer_id is None and sample_indices is None:
246
+ self.cache.clear()
247
+ elif layer_id is not None and sample_indices is None:
248
+ # Clear all samples for a specific layer
249
+ keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
250
+ for k in keys_to_remove:
251
+ del self.cache[k]
252
+ elif layer_id is not None and sample_indices is not None:
253
+ # Clear specific samples for a specific layer
254
+ for idx in sample_indices.tolist():
255
+ key = (layer_id, idx)
256
+ self.cache.pop(key, None)
257
+
258
+ class SConv1d(nn.Module):
259
+ """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
260
+ def __init__(self, in_channels: int, out_channels: int,
261
+ kernel_size: int, stride: int = 1, dilation: int = 1,
262
+ groups: int = 1, bias: bool = True, causal: bool = False,
263
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
264
+ pad_mode: str = 'reflect'):
265
+ super().__init__()
266
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
267
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
268
+ norm=norm, norm_kwargs=norm_kwargs)
269
+ self.causal = causal
270
+ self.pad_mode = pad_mode
271
+
272
+ # Store configuration
273
+ self.kernel_size = kernel_size
274
+ self.dilation = dilation
275
+ self.stride = stride
276
+ self.in_channels = in_channels
277
+ self.out_channels = out_channels
278
+
279
+ # For causal convolution, we need to maintain kernel_size - 1 samples as context
280
+ # need to check use which context_size is more suitable
281
+ # self.context_size = (kernel_size - 1) * dilation
282
+ self.context_size = (kernel_size - 1) * dilation - (stride - 1)
283
+
284
+ # For non-streaming mode, calculate padding
285
+ self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
286
+
287
+ # Create a unique layer ID for cache management
288
+ self._layer_id = None
289
+
290
+ @property
291
+ def layer_id(self):
292
+ if self._layer_id is None:
293
+ self._layer_id = f"sconv1d_{id(self)}"
294
+ return self._layer_id
295
+
296
+ def forward(self, x: torch.Tensor,
297
+ cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
298
+ sample_indices: Optional[torch.Tensor] = None,
299
+ use_cache: bool = False,
300
+ debug: bool = False) -> torch.Tensor:
301
+ """
302
+ Forward pass with optional streaming support via cache.
303
+
304
+ Args:
305
+ x: Input tensor [batch_size, channels, time]
306
+ cache: VibeVoiceTokenizerStreamingCache object for maintaining states
307
+ sample_indices: Indices identifying each sample for cache management
308
+ use_cache: Whether to use cached states for streaming
309
+ debug: Whether to print debug information
310
+
311
+ Returns:
312
+ Output tensor
313
+ """
314
+ B, C, T = x.shape
315
+
316
+ # Non-streaming mode
317
+ if not use_cache or cache is None:
318
+ return self._forward_non_streaming(x, debug=debug)
319
+
320
+ # Streaming mode
321
+ assert self.causal, "Streaming mode is only supported for causal convolutions"
322
+ assert sample_indices is not None, "sample_indices must be provided for streaming mode"
323
+ assert len(sample_indices) == B, "sample_indices must match batch size"
324
+
325
+ return self._forward_streaming(x, cache, sample_indices, debug)
326
+
327
+ def _forward_streaming(self, x: torch.Tensor,
328
+ cache: VibeVoiceTokenizerStreamingCache,
329
+ sample_indices: torch.Tensor,
330
+ debug: bool = False) -> torch.Tensor:
331
+ """Streaming forward pass with cache operations kept separate from compiled code"""
332
+ B, C, T = x.shape
333
+
334
+ # Cache operations (not compiled)
335
+ cached_states = cache.get(self.layer_id, sample_indices)
336
+
337
+ if cached_states is None:
338
+ # First chunk - initialize with zeros for context
339
+ if self.context_size > 0:
340
+ cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
341
+ if debug:
342
+ print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
343
+ else:
344
+ cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
345
+ if debug:
346
+ print(f"[DEBUG] No context needed (kernel_size=stride)")
347
+
348
+ # Concatenate cached states with input
349
+ if cached_states.shape[2] > 0:
350
+ input_with_context = torch.cat([cached_states, x], dim=2)
351
+ else:
352
+ input_with_context = x
353
+
354
+ if debug:
355
+ print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
356
+
357
+ # Apply convolution directly - no extra padding in streaming mode
358
+ # The conv layer will handle its own padding internally
359
+ output = self.conv(input_with_context)
360
+
361
+ if debug:
362
+ print(f"[DEBUG] Output shape: {output.shape}")
363
+
364
+ # Update cache for next chunk
365
+ if self.context_size > 0:
366
+ # Calculate how many samples to keep
367
+ total_input_length = input_with_context.shape[2]
368
+
369
+ # Keep the last context_size samples
370
+ if total_input_length >= self.context_size:
371
+ new_cache_start = total_input_length - self.context_size
372
+ new_cache = input_with_context[:, :, new_cache_start:]
373
+ else:
374
+ # If we have less than context_size samples, keep everything
375
+ new_cache = input_with_context
376
+
377
+ if debug:
378
+ print(f"[DEBUG] New cache shape: {new_cache.shape}")
379
+
380
+ cache.set(self.layer_id, sample_indices, new_cache)
381
+
382
+ return output
383
+
384
+ def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
385
+ """Standard forward pass without streaming"""
386
+ B, C, T = x.shape
387
+ kernel_size = self.kernel_size
388
+ stride = self.stride
389
+ dilation = self.dilation
390
+ padding_total = self.padding_total
391
+
392
+ # Compute extra padding for stride alignment
393
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
394
+
395
+ if debug:
396
+ print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
397
+
398
+ if self.causal:
399
+ # Left padding for causal
400
+ if self.pad_mode == 'constant':
401
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
402
+ else:
403
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
404
+ else:
405
+ # Symmetric padding for non-causal
406
+ padding_right = padding_total // 2
407
+ padding_left = padding_total - padding_right
408
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
409
+
410
+ if debug:
411
+ print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
412
+
413
+ output = self.conv(x)
414
+
415
+ if debug:
416
+ print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
417
+
418
+ return output
419
+
420
+
421
+ class SConvTranspose1d(nn.Module):
422
+ """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
423
+ def __init__(self, in_channels: int, out_channels: int,
424
+ kernel_size: int, stride: int = 1, causal: bool = False,
425
+ norm: str = 'none', trim_right_ratio: float = 1.,
426
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
427
+ super().__init__()
428
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
429
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
430
+ self.causal = causal
431
+ self.trim_right_ratio = trim_right_ratio
432
+ assert self.causal or self.trim_right_ratio == 1., \
433
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
434
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
435
+
436
+ # Store configuration
437
+ self.kernel_size = kernel_size
438
+ self.stride = stride
439
+ self.in_channels = in_channels
440
+ self.out_channels = out_channels
441
+
442
+ # For transposed convolution, padding calculation is different
443
+ self.padding_total = kernel_size - stride
444
+
445
+ # For streaming, we need to keep track of input history
446
+ # Transposed conv needs to see multiple input samples to produce correct output
447
+ self.context_size = kernel_size - 1
448
+
449
+ # Create a unique layer ID for cache management
450
+ self._layer_id = None
451
+
452
+ @property
453
+ def layer_id(self):
454
+ if self._layer_id is None:
455
+ self._layer_id = f"sconvtr1d_{id(self)}"
456
+ return self._layer_id
457
+
458
+ def forward(self, x: torch.Tensor,
459
+ cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
460
+ sample_indices: Optional[torch.Tensor] = None,
461
+ use_cache: bool = False,
462
+ debug: bool = False) -> torch.Tensor:
463
+ """
464
+ Forward pass with optional streaming support via cache.
465
+ """
466
+ B, C, T = x.shape
467
+
468
+ # Non-streaming mode
469
+ if not use_cache or cache is None:
470
+ return self._forward_non_streaming(x, debug=debug)
471
+
472
+ # Streaming mode
473
+ assert sample_indices is not None, "sample_indices must be provided for streaming mode"
474
+ assert len(sample_indices) == B, "sample_indices must match batch size"
475
+
476
+ return self._forward_streaming(x, cache, sample_indices, debug)
477
+
478
+ def _forward_streaming(self, x: torch.Tensor,
479
+ cache: VibeVoiceTokenizerStreamingCache,
480
+ sample_indices: torch.Tensor,
481
+ debug: bool = False) -> torch.Tensor:
482
+ """Streaming forward pass with cache operations kept separate from compiled code"""
483
+ B, C, T = x.shape
484
+
485
+ # Cache operations (not compiled)
486
+ cached_input = cache.get(self.layer_id, sample_indices)
487
+
488
+ if cached_input is None:
489
+ # First chunk - no history yet
490
+ cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
491
+ if debug:
492
+ print(f"[DEBUG] Initialized empty cache for transposed conv")
493
+
494
+ # Concatenate cached input with new input
495
+ full_input = torch.cat([cached_input, x], dim=2)
496
+
497
+ if debug:
498
+ print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
499
+
500
+ # First chunk or debug mode - use uncompiled version
501
+ full_output = self.convtr(full_input)
502
+
503
+ if debug:
504
+ print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
505
+
506
+ # Calculate padding to remove
507
+ if self.causal:
508
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
509
+ padding_left = self.padding_total - padding_right
510
+ else:
511
+ padding_right = self.padding_total // 2
512
+ padding_left = self.padding_total - padding_right
513
+
514
+ # Remove padding
515
+ if padding_left + padding_right > 0:
516
+ full_output = unpad1d(full_output, (padding_left, padding_right))
517
+
518
+ if debug:
519
+ print(f"[DEBUG] After unpadding: {full_output.shape}")
520
+
521
+ # Determine which part of the output corresponds to the new input
522
+ if cached_input.shape[2] == 0:
523
+ # First chunk - return all output
524
+ output = full_output
525
+ else:
526
+ # Subsequent chunks - return only the new output
527
+ expected_new_output = T * self.stride
528
+
529
+ # Take the last expected_new_output samples
530
+ if full_output.shape[2] >= expected_new_output:
531
+ output = full_output[:, :, -expected_new_output:]
532
+ else:
533
+ output = full_output
534
+
535
+ if debug:
536
+ print(f"[DEBUG] Final streaming output shape: {output.shape}")
537
+
538
+ # Update cache
539
+ if full_input.shape[2] > self.context_size:
540
+ new_cache = full_input[:, :, -self.context_size:]
541
+ else:
542
+ new_cache = full_input
543
+
544
+ if debug:
545
+ print(f"[DEBUG] New cache shape: {new_cache.shape}")
546
+
547
+ cache.set(self.layer_id, sample_indices, new_cache)
548
+
549
+ return output
550
+
551
+ def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
552
+ """Standard forward pass without streaming"""
553
+ if debug:
554
+ print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
555
+
556
+ # Apply transposed convolution
557
+ y = self.convtr(x)
558
+
559
+ if debug:
560
+ print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
561
+
562
+ # Calculate and remove padding
563
+ if self.causal:
564
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
565
+ padding_left = self.padding_total - padding_right
566
+ else:
567
+ padding_right = self.padding_total // 2
568
+ padding_left = self.padding_total - padding_right
569
+
570
+ if padding_left + padding_right > 0:
571
+ y = unpad1d(y, (padding_left, padding_right))
572
+
573
+ if debug:
574
+ print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
575
+
576
+ return y
577
+
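The chunk bookkeeping above is easiest to verify with concrete numbers. A minimal sketch, assuming `padding_total = kernel_size - stride` (the usual setup for these upsampling layers; the stride, kernel size, and frame counts below are illustrative, not taken from the configs):

import torch
import torch.nn as nn

stride, kernel_size = 4, 8                    # illustrative ratio and kernel
padding_total = kernel_size - stride          # assumed definition of self.padding_total
convtr = nn.ConvTranspose1d(1, 1, kernel_size, stride=stride)

for frames in (3, 5, 8):
    full = convtr(torch.randn(1, 1, frames))  # length = (frames - 1) * stride + kernel_size
    trimmed_len = full.shape[-1] - padding_total
    assert trimmed_len == frames * stride     # so a chunk of T new frames yields T * stride samples

This is why the streaming path can simply keep the last `T * self.stride` samples of the unpadded output for every chunk after the first.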
578
+ # FFN
579
+ class FFN(nn.Module):
580
+ def __init__(
581
+ self,
582
+ embed_dim,
583
+ ffn_dim,
584
+ bias=False,
585
+ ):
586
+ super().__init__()
587
+ self.embed_dim = embed_dim
588
+ self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
589
+ self.gelu = ACT2FN["gelu"]
590
+ self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
591
+
592
+ def forward(self, x):
593
+ x = self.linear1(x)
594
+ x = self.gelu(x)
595
+ x = self.linear2(x)
596
+ return x
597
+
598
+
599
+ class Convlayer(nn.Module):
600
+ def __init__(
601
+ self,
602
+ in_channels,
603
+ out_channels,
604
+ kernel_size,
605
+ stride=1,
606
+ dilation=1,
607
+ groups=1,
608
+ bias=True,
609
+ pad_mode='zeros',
610
+ norm='weight_norm',
611
+ causal=True,
612
+ ):
613
+ super().__init__()
614
+ self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
615
+ groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
616
+
617
+ def forward(self, x):
618
+ return self.conv(x)
619
+
620
+ class Block1D(nn.Module):
621
+ def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
622
+ layer_scale_init_value=1e-6, **kwargs):
623
+ super().__init__()
624
+
625
+ if kwargs.get('layernorm', 'LN') == 'LN':
626
+ self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
627
+ self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
628
+ elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
629
+ self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
630
+ self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
631
+
632
+ if mixer_layer == 'conv':
633
+ self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
634
+ kernel_size=kernel_size,
635
+ pad_mode=kwargs.get('pad_mode', 'reflect'),
636
+ norm=kwargs.get('norm', 'none'),
637
+ causal=kwargs.get('causal', True),
638
+ bias=kwargs.get('bias', True),
639
+ )
640
+ elif mixer_layer == 'depthwise_conv':
641
+ self.mixer = Convlayer(dim, dim, groups=dim,
642
+ kernel_size=kernel_size,
643
+ pad_mode=kwargs.get('pad_mode', 'reflect'),
644
+ norm=kwargs.get('norm', 'none'),
645
+ causal=kwargs.get('causal', True),
646
+ bias=kwargs.get('bias', True),
647
+ )
648
+ else:
649
+ raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
650
+
651
+ self.ffn = FFN(
652
+ dim,
653
+ kwargs.get('ffn_expansion', 4) * dim,
654
+ bias=kwargs.get('bias', False),
655
+ )
656
+ self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)
657
+
658
+ if layer_scale_init_value > 0:
659
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
660
+ self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
661
+ else:
662
+ self.gamma = None
663
+ self.ffn_gamma = None
664
+
665
+ def forward(self, x):
666
+ # mixer
667
+ residual = x
668
+ x = self.norm(x)
669
+ x = self.mixer(x)
670
+ if self.gamma is not None:
671
+ x = x * self.gamma.unsqueeze(-1)
672
+ x = residual + self.drop_path(x)
673
+
674
+ # ffn
675
+ residual = x
676
+ x = self.ffn_norm(x)
677
+ x = x.permute(0, 2, 1)
678
+ x = self.ffn(x)
679
+ x = x.permute(0, 2, 1)
680
+ if self.ffn_gamma is not None:
681
+ x = x * self.ffn_gamma.unsqueeze(-1)
682
+ x = residual + self.drop_path(x)
683
+
684
+ return x
685
+
686
+
687
+ class TokenizerEncoder(nn.Module):
688
+ """
689
+ Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
690
+
691
+ Args:
692
+ config: Configuration object with model parameters
693
+ """
694
+ def __init__(self, config):
695
+ super().__init__()
696
+
697
+ # Extract parameters from config
698
+ self.channels = config.channels
699
+ self.dimension = config.dimension
700
+ self.n_filters = config.n_filters
701
+ self.ratios = list(reversed(config.ratios))
702
+ self.depths = config.depths
703
+ self.n_residual_layers = getattr(config, "n_residual_layers", 1)
704
+ self.hop_length = np.prod(self.ratios)
705
+ self.causal = config.causal
706
+
707
+ # Additional config parameters with defaults
708
+ kernel_size = getattr(config, "kernel_size", 7)
709
+ last_kernel_size = getattr(config, "last_kernel_size", 7)
710
+ norm = getattr(config, "norm", "none")
711
+ norm_params = getattr(config, "norm_params", {})
712
+ pad_mode = getattr(config, "pad_mode", "reflect")
713
+ bias = getattr(config, "bias", True)
714
+ layernorm = getattr(config, "layernorm", "LN")
715
+ layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
716
+ layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
717
+ drop_path_rate = getattr(config, "drop_path_rate", 0.0)
718
+ mixer_layer = getattr(config, "mixer_layer", "conv")
719
+ layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
720
+ disable_last_norm = getattr(config, "disable_last_norm", False)
721
+
722
+ # determine the norm type based on layernorm
723
+ if layernorm == 'LN':
724
+ norm_type = ConvLayerNorm
725
+ elif layernorm == 'RMSNorm':
726
+ norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
727
+ else:
728
+ raise ValueError(f"Unsupported norm type: {layernorm}")
729
+
730
+ # stem and intermediate downsampling conv layers
731
+ stem = nn.Sequential(
732
+ SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
733
+ )
734
+
735
+ self.downsample_layers = nn.ModuleList()
736
+ self.downsample_layers.append(stem)
737
+ for i in range(len(self.ratios)):
738
+ in_ch = self.n_filters * (2 ** i)
739
+ out_ch = self.n_filters * (2 ** (i + 1))
740
+ downsample_layer = nn.Sequential(
741
+ SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
742
+ )
743
+ self.downsample_layers.append(downsample_layer)
744
+
745
+ # configure the transformer blocks
746
+ layer_type = partial(
747
+ Block1D,
748
+ mixer_layer=mixer_layer,
749
+ layernorm=layernorm,
750
+ eps=layernorm_eps,
751
+ causal=self.causal,
752
+ pad_mode=pad_mode,
753
+ norm=norm,
754
+ bias=bias,
755
+ layer_scale_init_value=layer_scale_init_value,
756
+ )
757
+
758
+ self.stages = nn.ModuleList()
759
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
760
+ cur = 0
761
+
762
+ for i in range(len(self.depths)):
763
+ in_ch = self.n_filters * (2 ** i)
764
+ stage = nn.Sequential(
765
+ *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
766
+ )
767
+ self.stages.append(stage)
768
+ cur += self.depths[i]
769
+
770
+ if not disable_last_norm:
771
+ self.norm = norm_type(in_ch, eps=layernorm_eps)
772
+ else:
773
+ self.norm = nn.Identity()
774
+ self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
775
+
776
+ def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
777
+ for i in range(len(self.depths)):
778
+ # Apply downsampling
779
+ for layer in self.downsample_layers[i]:
780
+ if isinstance(layer, SConv1d):
781
+ x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
782
+ else:
783
+ x = layer(x)
784
+
785
+ # Apply stage (Block1D contains Convlayer which contains SConv1d)
786
+ for block in self.stages[i]:
787
+ if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
788
+ # Block1D forward with cache support
789
+ residual = x
790
+ x = block.norm(x)
791
+ x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
792
+ if block.gamma is not None:
793
+ x = x * block.gamma.unsqueeze(-1)
794
+ x = residual + x
795
+
796
+ # FFN part
797
+ residual = x
798
+ x = block.ffn_norm(x)
799
+ x = x.permute(0, 2, 1)
800
+ x = block.ffn(x)
801
+ x = x.permute(0, 2, 1)
802
+ if block.ffn_gamma is not None:
803
+ x = x * block.ffn_gamma.unsqueeze(-1)
804
+ x = residual + x
805
+ else:
806
+ x = block(x)
807
+
808
+ return self.norm(x)
809
+
810
+ def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
811
+ x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
812
+ x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
813
+ return x
814
+
815
+
816
+ class TokenizerDecoder(nn.Module):
817
+ """
818
+ Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
819
+
820
+ Args:
821
+ config: Configuration object with model parameters
822
+ """
823
+ def __init__(self, config):
824
+ super().__init__()
825
+
826
+ # Extract parameters from config
827
+ self.dimension = config.dimension
828
+ self.channels = config.channels
829
+ self.n_filters = config.n_filters
830
+ self.ratios = config.ratios
831
+
832
+ # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
833
+ self.depths = config.depths # Changed from list(reversed(config.depths))
834
+
835
+ self.n_residual_layers = getattr(config, "n_residual_layers", 1)
836
+ self.hop_length = np.prod(self.ratios)
837
+ self.causal = config.causal
838
+
839
+ # Additional config parameters with defaults
840
+ kernel_size = getattr(config, "kernel_size", 7)
841
+ last_kernel_size = getattr(config, "last_kernel_size", 7)
842
+ norm = getattr(config, "norm", "none")
843
+ norm_params = getattr(config, "norm_params", {})
844
+ pad_mode = getattr(config, "pad_mode", "reflect")
845
+ bias = getattr(config, "bias", True)
846
+ layernorm = getattr(config, "layernorm", "LN")
847
+ layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
848
+ trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
849
+ layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
850
+ drop_path_rate = getattr(config, "drop_path_rate", 0.0)
851
+ mixer_layer = getattr(config, "mixer_layer", "conv")
852
+ layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
853
+ disable_last_norm = getattr(config, "disable_last_norm", False)
854
+
855
+ # determine the norm type based on layernorm
856
+ if layernorm == 'LN':
857
+ norm_type = ConvLayerNorm
858
+ elif layernorm == 'RMSNorm':
859
+ norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
860
+ else:
861
+ raise ValueError(f"Unsupported norm type: {layernorm}")
862
+
863
+ # stem and upsampling layers
864
+ stem = nn.Sequential(
865
+ SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
866
+ norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
867
+ )
868
+
869
+ self.upsample_layers = nn.ModuleList()
870
+ self.upsample_layers.append(stem)
871
+ for i in range(len(self.ratios)):
872
+ in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
873
+ out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
874
+ upsample_layer = nn.Sequential(
875
+ SConvTranspose1d(in_ch, out_ch,
876
+ kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
877
+ norm=norm, norm_kwargs=norm_params, bias=bias,
878
+ causal=self.causal, trim_right_ratio=trim_right_ratio),
879
+ )
880
+ self.upsample_layers.append(upsample_layer)
881
+
882
+ # configure transformer blocks
883
+ layer_type = partial(
884
+ Block1D,
885
+ mixer_layer=mixer_layer,
886
+ layernorm=layernorm,
887
+ eps=layernorm_eps,
888
+ causal=self.causal,
889
+ pad_mode=pad_mode,
890
+ norm=norm,
891
+ bias=bias,
892
+ layer_scale_init_value=layer_scale_init_value,
893
+ )
894
+
895
+ self.stages = nn.ModuleList()
896
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
897
+ cur = 0
898
+
899
+ # Create stages in the same order as the original model
900
+ for i in range(len(self.depths)):
901
+ in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
902
+ stage = nn.Sequential(
903
+ *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
904
+ )
905
+ self.stages.append(stage)
906
+ cur += self.depths[i]
907
+
908
+ if not disable_last_norm:
909
+ self.norm = norm_type(in_ch, eps=layernorm_eps)
910
+ else:
911
+ self.norm = nn.Identity()
912
+ self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
913
+
914
+ def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
915
+ for i in range(len(self.depths)):
916
+ # Apply upsampling
917
+ for layer in self.upsample_layers[i]:
918
+ if isinstance(layer, (SConv1d, SConvTranspose1d)):
919
+ x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
920
+ else:
921
+ x = layer(x)
922
+
923
+ # Apply stage (Block1D contains Convlayer which contains SConv1d)
924
+ for block in self.stages[i]:
925
+ if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
926
+ # Block1D forward with cache support
927
+ residual = x
928
+ x = block.norm(x)
929
+ x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
930
+ if block.gamma is not None:
931
+ x = x * block.gamma.unsqueeze(-1)
932
+ x = residual + x
933
+
934
+ # FFN part
935
+ residual = x
936
+ x = block.ffn_norm(x)
937
+ x = x.permute(0, 2, 1)
938
+ x = block.ffn(x)
939
+ x = x.permute(0, 2, 1)
940
+ if block.ffn_gamma is not None:
941
+ x = x * block.ffn_gamma.unsqueeze(-1)
942
+ x = residual + x
943
+ else:
944
+ x = block(x)
945
+
946
+ return self.norm(x)
947
+
948
+ def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
949
+ x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
950
+ x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
951
+ return x
952
+
953
+
954
+ @dataclass
955
+ class VibeVoiceTokenizerEncoderOutput:
956
+ """
957
+ Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
958
+
959
+ Args:
960
+ mean (`torch.FloatTensor`): The mean parameters of the distribution.
961
+ std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
962
+ """
963
+ mean: torch.Tensor
964
+ std: Optional[Union[float, torch.Tensor]] = None
965
+
966
+ def sample(self, dist_type='fix'):
967
+ """
968
+ Sample from the distribution.
969
+
970
+ Args:
971
+ dist_type (`str`): Sampling method: 'fix', 'gaussian', or any other value (e.g. 'none') to return the mean unchanged.
972
+
973
+ Returns:
974
+ `torch.FloatTensor`: Sampled values.
975
+ `torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
976
+ """
977
+ if dist_type == 'fix':
978
+ x = self.mean + self.std * torch.randn_like(self.mean)
979
+ return x, self.std
980
+ elif dist_type == 'gaussian':
981
+ batch_size = self.mean.size(0)
982
+ value = self.std / 0.8
983
+ std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
984
+
985
+ while std.dim() < self.mean.dim():
986
+ std = std.unsqueeze(-1)
987
+
988
+ x = self.mean + std * torch.randn_like(self.mean)
989
+ return x, std
990
+ else:
991
+ return self.mean, self.std
992
+
993
+ def kl(self):
994
+ """Compute KL divergence between this distribution and a standard normal."""
995
+ target = torch.zeros_like(self.mean)
996
+ return F.mse_loss(self.mean, target, reduction='none')
997
+
998
+ def mode(self):
999
+ """Return the distribution mode (which is the mean for Gaussian)."""
1000
+ return self.mean
1001
+
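For reference, the two sampling branches above differ only in where the noise scale comes from. A minimal sketch with illustrative shapes (`fix_std = 0.5` is the value used by the acoustic tokenizer configs in this upload; everything else below is made up for illustration):

import torch

mean = torch.zeros(2, 10, 64)                                  # (batch, frames, vae_dim), illustrative
fix_std = 0.5

# 'fix': one global standard deviation shared by every sample
fix_latents = mean + fix_std * torch.randn_like(mean)

# 'gaussian': a random per-sample scale drawn from N(0, (fix_std / 0.8)^2), broadcast over frames
per_sample_std = (fix_std / 0.8) * torch.randn(mean.size(0), 1, 1)
gauss_latents = mean + per_sample_std * torch.randn_like(mean)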
1002
+ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
1003
+ """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
1004
+
1005
+ config_class = VibeVoiceAcousticTokenizerConfig
1006
+ base_model_prefix = "vibevoice_acoustic_tokenizer"
1007
+ _supports_flash_attn_2 = True
1008
+ _supports_sdpa = True
1009
+ _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
1010
+
1011
+ def __init__(self, config):
1012
+ super().__init__(config)
1013
+
1014
+ self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
1015
+ self.std_dist_type = getattr(config, "std_dist_type", "fix")
1016
+
1017
+ # Parse encoder depths
1018
+ if isinstance(config.encoder_depths, str):
1019
+ encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1020
+ else:
1021
+ encoder_depths = config.encoder_depths
1022
+
1023
+ # Parse decoder depths if provided
1024
+ if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
1025
+ decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
1026
+ else:
1027
+ # Default: use reversed encoder depths if decoder_depths is None
1028
+ decoder_depths = list(reversed(encoder_depths))
1029
+
1030
+ # Create encoder config
1031
+ encoder_config = copy.deepcopy(config)
1032
+ encoder_config.dimension = config.vae_dim
1033
+ encoder_config.n_filters = config.encoder_n_filters
1034
+ encoder_config.ratios = config.encoder_ratios
1035
+ encoder_config.depths = encoder_depths
1036
+ encoder_config.norm = config.conv_norm
1037
+ encoder_config.pad_mode = config.pad_mode
1038
+ encoder_config.bias = config.conv_bias
1039
+ encoder_config.layernorm_eps = config.layernorm_eps
1040
+ encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1041
+ encoder_config.mixer_layer = config.mixer_layer
1042
+ encoder_config.layer_scale_init_value = config.layer_scale_init_value
1043
+ encoder_config.disable_last_norm = config.disable_last_norm
1044
+
1045
+ # Create decoder config
1046
+ decoder_config = copy.deepcopy(config)
1047
+ decoder_config.dimension = config.vae_dim
1048
+ decoder_config.n_filters = config.decoder_n_filters
1049
+ decoder_config.ratios = config.decoder_ratios
1050
+ decoder_config.depths = decoder_depths
1051
+ decoder_config.norm = config.conv_norm
1052
+ decoder_config.pad_mode = config.pad_mode
1053
+ decoder_config.bias = config.conv_bias
1054
+ decoder_config.layernorm_eps = config.layernorm_eps
1055
+ decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1056
+ decoder_config.mixer_layer = config.mixer_layer
1057
+ decoder_config.layer_scale_init_value = config.layer_scale_init_value
1058
+ decoder_config.disable_last_norm = config.disable_last_norm
1059
+
1060
+ # Initialize encoder and decoder
1061
+ self.encoder = TokenizerEncoder(encoder_config)
1062
+ self.decoder = TokenizerDecoder(decoder_config)
1063
+
1064
+ # Initialize weights
1065
+ self.apply(self._init_weights)
1066
+
1067
+ def _init_weights(self, module):
1068
+ """Initialize weights for the model"""
1069
+ if isinstance(module, nn.Linear):
1070
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1071
+ if module.bias is not None:
1072
+ nn.init.zeros_(module.bias)
1073
+ elif isinstance(module, nn.LayerNorm):
1074
+ nn.init.ones_(module.weight)
1075
+ nn.init.zeros_(module.bias)
1076
+ elif isinstance(module, nn.Conv1d):
1077
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1078
+ if module.bias is not None:
1079
+ nn.init.zeros_(module.bias)
1080
+
1081
+ @torch.no_grad()
1082
+ def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1083
+ """Convert audio to latent representations"""
1084
+ latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1085
+ return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
1086
+
1087
+ @torch.no_grad()
1088
+ def sampling(self, encoder_output, dist_type=None):
1089
+ """Sample from the encoder output distribution"""
1090
+ dist_type = dist_type or self.std_dist_type
1091
+
1092
+ if dist_type == 'fix':
1093
+ return encoder_output.sample(dist_type='fix')
1094
+ elif dist_type == 'gaussian':
1095
+ return encoder_output.sample(dist_type='gaussian')
1096
+ else:
1097
+ raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
1098
+
1099
+ @torch.no_grad()
1100
+ def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
1101
+ """Convert latent representations back to audio"""
1102
+ if latents.shape[1] != self.config.vae_dim:
1104
+ latents = latents.permute(0, 2, 1)
1106
+
1107
+ audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1108
+ return audio
1109
+
1110
+ def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1111
+ """Full forward pass: encode audio to latents, then decode back to audio"""
1112
+ encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1113
+ sampled_latents, _ = self.sampling(encoder_output)
1114
+ reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1115
+ return reconstructed, sampled_latents
1116
+
1117
+
1118
+ class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
1119
+ """VibeVoice speech tokenizer model with only encoder for semantic tokens"""
1120
+
1121
+ config_class = VibeVoiceSemanticTokenizerConfig
1122
+ base_model_prefix = "vibevoice_semantic_tokenizer"
1123
+ _supports_flash_attn_2 = True
1124
+ _supports_sdpa = True
1125
+ _no_split_modules = ["TokenizerEncoder"]
1126
+
1127
+ def __init__(self, config):
1128
+ super().__init__(config)
1129
+
1130
+ # Parse encoder depths
1131
+ if isinstance(config.encoder_depths, str):
1132
+ encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1133
+ else:
1134
+ encoder_depths = config.encoder_depths
1135
+
1136
+ # Create encoder config
1137
+ encoder_config = copy.deepcopy(config)
1138
+ encoder_config.dimension = config.vae_dim
1139
+ encoder_config.n_filters = config.encoder_n_filters
1140
+ encoder_config.ratios = config.encoder_ratios
1141
+ encoder_config.depths = encoder_depths
1142
+ encoder_config.norm = config.conv_norm
1143
+ encoder_config.pad_mode = config.pad_mode
1144
+ encoder_config.bias = config.conv_bias
1145
+ encoder_config.layernorm_eps = config.layernorm_eps
1146
+ encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1147
+ encoder_config.mixer_layer = config.mixer_layer
1148
+ encoder_config.layer_scale_init_value = config.layer_scale_init_value
1149
+ encoder_config.disable_last_norm = config.disable_last_norm
1150
+
1151
+ # Initialize encoder (the semantic tokenizer has no decoder)
1152
+ self.encoder = TokenizerEncoder(encoder_config)
1153
+
1154
+ # Initialize weights
1155
+ self.apply(self._init_weights)
1156
+
1157
+ def _init_weights(self, module):
1158
+ """Initialize weights for the model"""
1159
+ if isinstance(module, nn.Linear):
1160
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1161
+ if module.bias is not None:
1162
+ nn.init.zeros_(module.bias)
1163
+ elif isinstance(module, nn.LayerNorm):
1164
+ nn.init.ones_(module.weight)
1165
+ nn.init.zeros_(module.bias)
1166
+ elif isinstance(module, nn.Conv1d):
1167
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1168
+ if module.bias is not None:
1169
+ nn.init.zeros_(module.bias)
1170
+
1171
+ @torch.no_grad()
1172
+ def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1173
+ """Convert audio to latent representations"""
1174
+ latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1175
+ return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
1176
+
1177
+ @torch.no_grad()
1178
+ def sampling(self, encoder_output, dist_type=None):
1179
+ """Sample from the encoder output distribution"""
1180
+ return encoder_output.sample(dist_type='none')
1181
+
1182
+ def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1183
+ """Full forward pass: encode audio to latents, then decode back to audio"""
1184
+ encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1185
+ sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
1186
+ return None, sampled_latents
1187
+
1188
+ AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
1189
+ AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
1190
+
1191
+ __all__ = [
1192
+ "VibeVoiceTokenizerStreamingCache",
1193
+ "VibeVoiceAcousticTokenizerModel",
1194
+ "VibeVoiceSemanticTokenizerModel",
1195
+ ]
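A minimal encode/decode round trip for the acoustic tokenizer, assuming a local checkpoint directory; the path, sampling rate, and clip length below are placeholders, not files shipped in this upload:

import torch

tokenizer = VibeVoiceAcousticTokenizerModel.from_pretrained("path/to/acoustic_tokenizer").eval()

audio = torch.randn(1, 1, 24000)                             # 1 s of mono audio at 24 kHz (illustrative)
enc_out = tokenizer.encode(audio)                            # VibeVoiceTokenizerEncoderOutput
latents, _ = tokenizer.sampling(enc_out, dist_type="fix")    # (1, frames, vae_dim)
reconstruction = tokenizer.decode(latents)                   # back to a (1, 1, ~24000) waveform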
modular/streamer.py ADDED
@@ -0,0 +1,264 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ import asyncio
6
+ from queue import Empty, Queue
7
+ from typing import TYPE_CHECKING, Any, Optional
8
+
9
+
10
+ from transformers.generation import BaseStreamer
11
+
12
+
13
+ class AudioStreamer(BaseStreamer):
14
+ """
15
+ Audio streamer that stores audio chunks in queues for each sample in the batch.
16
+ This allows streaming audio generation for multiple samples simultaneously.
17
+
18
+ Parameters:
19
+ batch_size (`int`):
20
+ The batch size for generation
21
+ stop_signal (`any`, *optional*):
22
+ The signal to put in the queue when generation ends. Defaults to None.
23
+ timeout (`float`, *optional*):
24
+ The timeout for the audio queue. If `None`, the queue will block indefinitely.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ batch_size: int,
30
+ stop_signal: Optional[Any] = None,
31
+ timeout: Optional[float] = None,
32
+ ):
33
+ self.batch_size = batch_size
34
+ self.stop_signal = stop_signal
35
+ self.timeout = timeout
36
+
37
+ # Create a queue for each sample in the batch
38
+ self.audio_queues = [Queue() for _ in range(batch_size)]
39
+ self.finished_flags = [False for _ in range(batch_size)]
40
+ self.sample_indices_map = {} # Maps from sample index to queue index
41
+
42
+ def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
43
+ """
44
+ Receives audio chunks and puts them in the appropriate queues.
45
+
46
+ Args:
47
+ audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
48
+ sample_indices: Tensor indicating which samples these chunks belong to
49
+ """
50
+ for i, sample_idx in enumerate(sample_indices):
51
+ idx = sample_idx.item()
52
+ if idx < self.batch_size and not self.finished_flags[idx]:
53
+ # Convert to numpy or keep as tensor based on preference
54
+ audio_chunk = audio_chunks[i].detach().cpu()
55
+ self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
56
+
57
+ def end(self, sample_indices: Optional[torch.Tensor] = None):
58
+ """
59
+ Signals the end of generation for specified samples or all samples.
60
+
61
+ Args:
62
+ sample_indices: Optional tensor of sample indices to end. If None, ends all.
63
+ """
64
+ if sample_indices is None:
65
+ # End all samples
66
+ for idx in range(self.batch_size):
67
+ if not self.finished_flags[idx]:
68
+ self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
69
+ self.finished_flags[idx] = True
70
+ else:
71
+ # End specific samples
72
+ for sample_idx in sample_indices:
73
+ idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
74
+ if idx < self.batch_size and not self.finished_flags[idx]:
75
+ self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
76
+ self.finished_flags[idx] = True
77
+
78
+ def __iter__(self):
79
+ """Returns an iterator over the batch of audio streams."""
80
+ return AudioBatchIterator(self)
81
+
82
+ def get_stream(self, sample_idx: int):
83
+ """Get the audio stream for a specific sample."""
84
+ if sample_idx >= self.batch_size:
85
+ raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
86
+ return AudioSampleIterator(self, sample_idx)
87
+
88
+
89
+ class AudioSampleIterator:
90
+ """Iterator for a single audio stream from the batch."""
91
+
92
+ def __init__(self, streamer: AudioStreamer, sample_idx: int):
93
+ self.streamer = streamer
94
+ self.sample_idx = sample_idx
95
+
96
+ def __iter__(self):
97
+ return self
98
+
99
+ def __next__(self):
100
+ value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
101
+ if value == self.streamer.stop_signal:
102
+ raise StopIteration()
103
+ return value
104
+
105
+
106
+ class AudioBatchIterator:
107
+ """Iterator that yields audio chunks for all samples in the batch."""
108
+
109
+ def __init__(self, streamer: AudioStreamer):
110
+ self.streamer = streamer
111
+ self.active_samples = set(range(streamer.batch_size))
112
+
113
+ def __iter__(self):
114
+ return self
115
+
116
+ def __next__(self):
117
+ if not self.active_samples:
118
+ raise StopIteration()
119
+
120
+ batch_chunks = {}
121
+ samples_to_remove = set()
122
+
123
+ # Try to get chunks from all active samples
124
+ for idx in self.active_samples:
125
+ try:
126
+ value = self.streamer.audio_queues[idx].get(block=False)
127
+ if value == self.streamer.stop_signal:
128
+ samples_to_remove.add(idx)
129
+ else:
130
+ batch_chunks[idx] = value
131
+ except Empty:
132
+ # Queue is empty for this sample, skip it this iteration
133
+ pass
134
+
135
+ # Remove finished samples
136
+ self.active_samples -= samples_to_remove
137
+
138
+ if batch_chunks:
139
+ return batch_chunks
140
+ elif self.active_samples:
141
+ # If no chunks were ready but we still have active samples,
142
+ # wait a bit and try again
143
+ import time
144
+ time.sleep(0.01)
145
+ return self.__next__()
146
+ else:
147
+ raise StopIteration()
148
+
149
+
150
+ class AsyncAudioStreamer(AudioStreamer):
151
+ """
152
+ Async version of AudioStreamer for use in async contexts.
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ batch_size: int,
158
+ stop_signal: Optional[Any] = None,
159
+ timeout: Optional[float] = None,
160
+ ):
161
+ super().__init__(batch_size, stop_signal, timeout)
162
+ # Replace regular queues with async queues
163
+ self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
164
+ self.loop = asyncio.get_running_loop()
165
+
166
+ def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
167
+ """Put audio chunks in the appropriate async queues."""
168
+ for i, sample_idx in enumerate(sample_indices):
169
+ idx = sample_idx.item()
170
+ if idx < self.batch_size and not self.finished_flags[idx]:
171
+ audio_chunk = audio_chunks[i].detach().cpu()
172
+ self.loop.call_soon_threadsafe(
173
+ self.audio_queues[idx].put_nowait, audio_chunk
174
+ )
175
+
176
+ def end(self, sample_indices: Optional[torch.Tensor] = None):
177
+ """Signal the end of generation for specified samples."""
178
+ if sample_indices is None:
179
+ indices_to_end = range(self.batch_size)
180
+ else:
181
+ indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
182
+
183
+ for idx in indices_to_end:
184
+ if idx < self.batch_size and not self.finished_flags[idx]:
185
+ self.loop.call_soon_threadsafe(
186
+ self.audio_queues[idx].put_nowait, self.stop_signal
187
+ )
188
+ self.finished_flags[idx] = True
189
+
190
+ async def get_stream(self, sample_idx: int):
191
+ """Get async iterator for a specific sample's audio stream."""
192
+ if sample_idx >= self.batch_size:
193
+ raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
194
+
195
+ while True:
196
+ value = await self.audio_queues[sample_idx].get()
197
+ if value == self.stop_signal:
198
+ break
199
+ yield value
200
+
201
+ def __aiter__(self):
202
+ """Returns an async iterator over all audio streams."""
203
+ return AsyncAudioBatchIterator(self)
204
+
205
+
206
+ class AsyncAudioBatchIterator:
207
+ """Async iterator for batch audio streaming."""
208
+
209
+ def __init__(self, streamer: AsyncAudioStreamer):
210
+ self.streamer = streamer
211
+ self.active_samples = set(range(streamer.batch_size))
212
+
213
+ def __aiter__(self):
214
+ return self
215
+
216
+ async def __anext__(self):
217
+ if not self.active_samples:
218
+ raise StopAsyncIteration()
219
+
220
+ batch_chunks = {}
221
+ samples_to_remove = set()
222
+
223
+ # Create tasks for all active samples
224
+ tasks = {
225
+ idx: asyncio.create_task(self._get_chunk(idx))
226
+ for idx in self.active_samples
227
+ }
228
+
229
+ # Wait for at least one chunk to be ready
230
+ done, pending = await asyncio.wait(
231
+ tasks.values(),
232
+ return_when=asyncio.FIRST_COMPLETED,
233
+ timeout=self.streamer.timeout
234
+ )
235
+
236
+ # Cancel pending tasks
237
+ for task in pending:
238
+ task.cancel()
239
+
240
+ # Process completed tasks
241
+ for idx, task in tasks.items():
242
+ if task in done:
243
+ try:
244
+ value = await task
245
+ if value == self.streamer.stop_signal:
246
+ samples_to_remove.add(idx)
247
+ else:
248
+ batch_chunks[idx] = value
249
+ except asyncio.CancelledError:
250
+ pass
251
+
252
+ self.active_samples -= samples_to_remove
253
+
254
+ if batch_chunks:
255
+ return batch_chunks
256
+ elif self.active_samples:
257
+ # Try again if we still have active samples
258
+ return await self.__anext__()
259
+ else:
260
+ raise StopAsyncIteration()
261
+
262
+ async def _get_chunk(self, idx):
263
+ """Helper to get a chunk from a specific queue."""
264
+ return await self.streamer.audio_queues[idx].get()
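A minimal producer/consumer sketch for `AudioStreamer`; the thread below stands in for the generation loop, which would call `put()` and `end()` with real audio chunks (batch size and chunk length are illustrative):

import threading
import torch

streamer = AudioStreamer(batch_size=2)

def fake_generation_loop():
    for _ in range(3):
        chunks = torch.randn(2, 3200)                            # one chunk per sample in the batch
        streamer.put(chunks, sample_indices=torch.tensor([0, 1]))
    streamer.end()                                               # signals every queue to stop

threading.Thread(target=fake_generation_loop).start()
for chunk in streamer.get_stream(0):                             # blocks until sample 0 finishes
    print(chunk.shape)                                           # torch.Size([3200]), three times

The same pattern applies to `AsyncAudioStreamer`, except consumption happens with `async for` inside a running event loop.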
processor/__init__.py ADDED
File without changes
processor/vibevoice_processor.py ADDED
@@ -0,0 +1,677 @@
1
+ import math
2
+ import warnings
3
+ from typing import List, Optional, Union, Dict, Any, Tuple
4
+ import os
5
+ import re
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
11
+ from transformers.utils import TensorType, logging
12
+ from .vibevoice_tokenizer_processor import AudioNormalizer
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
+ class VibeVoiceProcessor:
18
+ r"""
19
+ Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
20
+
21
+ [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
22
+ See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
23
+
24
+ Args:
25
+ tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
26
+ The tokenizer for text processing.
27
+ audio_processor (`VibeVoiceTokenizerProcessor`):
28
+ The audio processor for speech processing.
29
+ speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
30
+ The compression ratio for speech tokenization.
31
+ db_normalize (`bool`, *optional*, defaults to True):
32
+ Whether to apply decibel normalization to audio inputs.
33
+ """
34
+
35
+ def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
36
+ self.tokenizer = tokenizer
37
+ self.audio_processor = audio_processor
38
+ self.speech_tok_compress_ratio = speech_tok_compress_ratio
39
+ self.db_normalize = db_normalize
40
+ self.audio_normalizer = AudioNormalizer() if db_normalize else None
41
+ self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
42
+
43
+ @classmethod
44
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
45
+ """
46
+ Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
47
+
48
+ Args:
49
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
50
+ This can be either:
51
+ - a string, the *model id* of a pretrained model
52
+ - a path to a *directory* containing processor config
53
+
54
+ Returns:
55
+ [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
56
+ """
57
+ import os
58
+ import json
59
+ from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
60
+ from vibevoice.modular.modular_vibevoice_text_tokenizer import (
61
+ VibeVoiceTextTokenizer,
62
+ VibeVoiceTextTokenizerFast
63
+ )
64
+
65
+ # Load processor configuration
66
+ config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
67
+ if os.path.exists(config_path):
68
+ with open(config_path, 'r') as f:
69
+ config = json.load(f)
70
+ else:
71
+ logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
72
+ config = {
73
+ "speech_tok_compress_ratio": 3200,
74
+ "db_normalize": True,
75
+ }
76
+
77
+ # Extract main processor parameters
78
+ speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
79
+ db_normalize = config.get("db_normalize", True)
80
+
81
+ # Load tokenizer - try from model path first, then fallback to Qwen
82
+ language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
83
+ logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
84
+ if 'qwen' in language_model_pretrained_name.lower():
85
+ tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
86
+ language_model_pretrained_name,
87
+ **kwargs
88
+ )
89
+ else:
90
+ raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
91
+
92
+ # Load audio processor
93
+ if "audio_processor" in config:
94
+ # Create audio processor from config
95
+ audio_config = config["audio_processor"]
96
+ audio_processor = VibeVoiceTokenizerProcessor(
97
+ sampling_rate=audio_config.get("sampling_rate", 24000),
98
+ normalize_audio=audio_config.get("normalize_audio", True),
99
+ target_dB_FS=audio_config.get("target_dB_FS", -25),
100
+ eps=audio_config.get("eps", 1e-6),
101
+ )
102
+ else:
103
+ # Create default audio processor
104
+ audio_processor = VibeVoiceTokenizerProcessor()
105
+
106
+ # Create and return the processor
107
+ return cls(
108
+ tokenizer=tokenizer,
109
+ audio_processor=audio_processor,
110
+ speech_tok_compress_ratio=speech_tok_compress_ratio,
111
+ db_normalize=db_normalize,
112
+ )
113
+
114
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
115
+ """
116
+ Save a processor to a directory, so that it can be re-loaded using the
117
+ [`~VibeVoiceProcessor.from_pretrained`] class method.
118
+
119
+ Args:
120
+ save_directory (`str` or `os.PathLike`):
121
+ Directory where the processor will be saved.
122
+ """
123
+ import os
124
+ import json
125
+
126
+ os.makedirs(save_directory, exist_ok=True)
127
+
128
+ # Save processor configuration
129
+ processor_config = {
130
+ "processor_class": "VibeVoiceProcessor",
131
+ "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
132
+ "db_normalize": self.db_normalize,
133
+ "audio_processor": {
134
+ "feature_extractor_type": "VibeVoiceTokenizerProcessor",
135
+ "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
136
+ "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
137
+ "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
138
+ "eps": getattr(self.audio_processor, 'eps', 1e-6),
139
+ }
140
+ }
141
+
142
+ config_path = os.path.join(save_directory, "preprocessor_config.json")
143
+ with open(config_path, 'w') as f:
144
+ json.dump(processor_config, f, indent=2)
145
+
146
+ logger.info(f"Processor configuration saved in {config_path}")
147
+
148
+ def __call__(
149
+ self,
150
+ text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
151
+ voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
152
+ padding: Union[bool, str, PaddingStrategy] = True,
153
+ truncation: Union[bool, str, TruncationStrategy] = False,
154
+ max_length: Optional[int] = None,
155
+ return_tensors: Optional[Union[str, TensorType]] = None,
156
+ return_attention_mask: bool = True,
157
+ **kwargs,
158
+ ) -> BatchEncoding:
159
+ """
160
+ Main method to process one or more podcast scripts with optional voice samples.
161
+
162
+ Args:
163
+ text (`str`, `List[str]`):
164
+ The input text(s) to process. Can be:
165
+ - A single script string
166
+ - A list of script strings for batch processing
167
+ - A path to a .json or .txt file
168
+ - A list of paths
169
+ voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
170
+ Voice samples for each script. Can be:
171
+ - A list of samples for a single script
172
+ - A list of lists for batch processing
173
+ padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
174
+ Whether to pad sequences to the same length
175
+ truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
176
+ Whether to truncate sequences
177
+ max_length (`int`, *optional*):
178
+ Maximum length of the returned sequences
179
+ return_tensors (`str` or `TensorType`, *optional*):
180
+ If set, will return tensors of a particular framework
181
+ return_attention_mask (`bool`, defaults to `True`):
182
+ Whether to return the attention mask
183
+
184
+ Returns:
185
+ `BatchEncoding`: A BatchEncoding with the following fields:
186
+ - **input_ids** -- List of token id sequences or tensor
187
+ - **attention_mask** -- List of attention masks or tensor
188
+ - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
189
+ - **speech_masks** -- Speech masks (if voice_samples provided)
190
+ - **speech_input_mask** -- Boolean masks indicating speech token positions
191
+ """
192
+ # Handle single vs batch input
193
+ if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
194
+ # Single input
195
+ texts = [text]
196
+ is_batched = False
197
+ else:
198
+ # Batch input
199
+ texts = text
200
+ is_batched = True
201
+
202
+ # Handle voice samples
203
+ if voice_samples is not None:
204
+ if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
205
+ # Single set of voice samples
206
+ voice_samples_list = [voice_samples]
207
+ else:
208
+ # Batch of voice samples
209
+ voice_samples_list = voice_samples
210
+ else:
211
+ voice_samples_list = [None] * len(texts)
212
+
213
+ # Process each input
214
+ all_encodings = []
215
+ for text_input, voice_input in zip(texts, voice_samples_list):
216
+ encoding = self._process_single(text_input, voice_input)
217
+ all_encodings.append(encoding)
218
+
219
+ # Combine batch
220
+ batch_encoding = self._batch_encode(
221
+ all_encodings,
222
+ padding=padding,
223
+ truncation=truncation,
224
+ max_length=max_length,
225
+ return_tensors=return_tensors,
226
+ return_attention_mask=return_attention_mask,
227
+ )
228
+
229
+ return batch_encoding
230
+
231
+ def _process_single(
232
+ self,
233
+ text: Union[str, TextInput],
234
+ voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
235
+ ) -> Dict[str, Any]:
236
+ """Process a single podcast script."""
237
+ # Determine if text is a file path or direct script
238
+ script = None
239
+ if isinstance(text, str):
240
+ # Check if it's a file path
241
+ if text.endswith('.json') and os.path.exists(text):
242
+ script = self._convert_json_to_script(text)
243
+ elif text.endswith('.txt') and os.path.exists(text):
244
+ script = self._convert_text_to_script(text)
245
+ else:
246
+ # Assume it's the script content directly
247
+ script = text
248
+
249
+ if script is None:
250
+ raise ValueError(f"Could not process input text: {text}")
251
+
252
+ # Parse the script
253
+ parsed_lines = self._parse_script(script)
254
+ all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
255
+
256
+ # Create system prompt
257
+ # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
258
+ system_tokens = self.tokenizer.encode(self.system_prompt)
259
+
260
+ # Process voice samples if provided
261
+ if voice_samples:
262
+ voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
263
+ else:
264
+ voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
265
+
266
+ # Build full token sequence
267
+ full_tokens = system_tokens + voice_tokens
268
+ speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
269
+
270
+ # Add text input section
271
+ full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
272
+ speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
273
+
274
+ for speaker_id, speaker_text in parsed_lines:
275
+ speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
276
+ full_tokens += speaker_text_tokens
277
+ speech_input_mask += [False] * len(speaker_text_tokens)
278
+
279
+ # Add speech output section
280
+ full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
281
+ speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
282
+
283
+ return {
284
+ "input_ids": full_tokens,
285
+ "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
286
+ "speech_input_mask": speech_input_mask,
287
+ "parsed_script": parsed_lines,
288
+ "all_speakers": all_speakers,
289
+ }
290
+
291
+ def _batch_encode(
292
+ self,
293
+ encodings: List[Dict[str, Any]],
294
+ padding: Union[bool, str, PaddingStrategy] = True,
295
+ truncation: Union[bool, str, TruncationStrategy] = False,
296
+ max_length: Optional[int] = None,
297
+ return_tensors: Optional[Union[str, TensorType]] = None,
298
+ return_attention_mask: bool = True,
299
+ ) -> BatchEncoding:
300
+ """Combine multiple encodings into a batch with padding."""
301
+ # Extract input_ids and create attention_mask
302
+ input_ids_list = [enc["input_ids"] for enc in encodings]
303
+ speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
304
+
305
+ # Determine padding strategy
306
+ if isinstance(padding, bool):
307
+ padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
308
+ elif isinstance(padding, str):
309
+ padding_strategy = PaddingStrategy(padding)
310
+ else:
311
+ padding_strategy = padding
312
+
313
+ # Apply padding to input_ids
314
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD:
315
+ if padding_strategy == PaddingStrategy.LONGEST:
316
+ max_len = max(len(ids) for ids in input_ids_list)
317
+ elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
318
+ max_len = max_length
319
+ else:
320
+ max_len = max(len(ids) for ids in input_ids_list)
321
+
322
+ # Pad sequences
323
+ padded_input_ids = []
324
+ attention_masks = []
325
+ padded_speech_input_masks = []
326
+
327
+ for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
328
+ # Truncate if needed
329
+ if truncation and len(input_ids) > max_len:
330
+ input_ids = input_ids[:max_len]
331
+ speech_mask = speech_mask[:max_len]
332
+
333
+ # Pad
334
+ padding_length = max_len - len(input_ids)
335
+ # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
336
+ padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
337
+ attention_mask = [0] * padding_length + [1] * len(input_ids)
338
+ padded_speech_mask = [False] * padding_length + speech_mask
339
+
340
+ padded_input_ids.append(padded_ids)
341
+ attention_masks.append(attention_mask)
342
+ padded_speech_input_masks.append(padded_speech_mask)
343
+
344
+ input_ids_list = padded_input_ids
345
+ speech_input_masks_list = padded_speech_input_masks
346
+ else:
347
+ # No padding, just create attention masks
348
+ attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
349
+
350
+ # Process speech inputs
351
+ all_speech_inputs = []
352
+ has_speech = False
353
+ for enc in encodings:
354
+ if enc["speech_inputs"] is not None:
355
+ all_speech_inputs.extend(enc["speech_inputs"])
356
+ has_speech = True
357
+
358
+ # Prepare batch encoding
359
+ batch_encoding = BatchEncoding()
360
+
361
+ # Handle tensor conversion
362
+ if return_tensors is not None:
363
+ batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
364
+ if return_attention_mask and attention_masks is not None:
365
+ batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
366
+ batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
367
+ else:
368
+ batch_encoding["input_ids"] = input_ids_list
369
+ if return_attention_mask and attention_masks is not None:
370
+ batch_encoding["attention_mask"] = attention_masks
371
+ batch_encoding["speech_input_mask"] = speech_input_masks_list
372
+
373
+ # Process speech tensors if present
374
+ if has_speech:
375
+ speech_dict = self.prepare_speech_inputs(
376
+ all_speech_inputs,
377
+ return_tensors=return_tensors,
378
+ )
379
+ batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
380
+ batch_encoding["speech_masks"] = speech_dict["speech_masks"]
381
+ else:
382
+ batch_encoding["speech_tensors"] = None
383
+ batch_encoding["speech_masks"] = None
384
+
385
+ # Add metadata
386
+ batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
387
+ batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
388
+
389
+ return batch_encoding
390
+
391
+ def _create_voice_prompt(
392
+ self,
393
+ speaker_samples: List[Union[str, np.ndarray]]
394
+ ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
395
+ """
396
+ Create voice prompt tokens and process audio samples.
397
+
398
+ Returns:
399
+ tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
400
+ """
401
+ vae_token_id = self.tokenizer.speech_diffusion_id
402
+
403
+ voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
404
+ voice_speech_inputs = []
405
+ voice_speech_masks = [False] * len(voice_full_tokens)
406
+
407
+ for speaker_id, speaker_audio in enumerate(speaker_samples):
408
+ prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
409
+
410
+ # Process audio
411
+ if isinstance(speaker_audio, str):
412
+ # Load audio from file
413
+ wav = self.audio_processor._load_audio_from_path(speaker_audio)
414
+ else:
415
+ wav = np.array(speaker_audio, dtype=np.float32)
416
+
417
+ # Apply normalization if needed
418
+ if self.db_normalize and self.audio_normalizer:
419
+ wav = self.audio_normalizer(wav)
420
+
421
+ # Calculate token length based on compression ratio
422
+ # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
423
+ # vae_tok_len = wav.shape[0]
424
+ # else:
425
+ vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
426
+
427
+ # Build tokens and masks
428
+ speaker_tokens = (prefix_tokens +
429
+ [self.tokenizer.speech_start_id] +
430
+ [vae_token_id] * vae_tok_len +
431
+ [self.tokenizer.speech_end_id] +
432
+ self.tokenizer.encode('\n', add_special_tokens=False))
433
+
434
+ vae_input_mask = ([False] * len(prefix_tokens) +
435
+ [False] +
436
+ [True] * vae_tok_len +
437
+ [False] +
438
+ [False])
439
+
440
+ voice_full_tokens.extend(speaker_tokens)
441
+ voice_speech_masks.extend(vae_input_mask)
442
+ voice_speech_inputs.append(wav)
443
+
444
+ return voice_full_tokens, voice_speech_inputs, voice_speech_masks
445
+
446
+ def prepare_speech_inputs(
447
+ self,
448
+ speech_inputs: List[np.ndarray],
449
+ return_tensors: Optional[Union[str, TensorType]] = None,
450
+ device: Optional[Union[str, torch.device]] = None,
451
+ dtype: Optional[torch.dtype] = None,
452
+ ) -> Dict[str, Any]:
453
+ """
454
+ Prepare speech inputs for model consumption.
455
+
456
+ Args:
457
+ speech_inputs: List of speech arrays
458
+ return_tensors: Output tensor type
459
+ device: Device to place tensors on
460
+ dtype: Data type for tensors
461
+
462
+ Returns:
463
+ Dictionary with padded_speeches and speech_masks
464
+ """
465
+ if not speech_inputs:
466
+ return {"padded_speeches": None, "speech_masks": None}
467
+
468
+ # Calculate sequence lengths
469
+ vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
470
+ # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
471
+ max_speech_length = max(s.shape[0] for s in speech_inputs)
472
+
473
+ # Pad speeches
474
+ if speech_inputs[0].ndim == 1:
475
+ padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
476
+ else:
477
+ padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
478
+ speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
479
+
480
+ for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
481
+ padded_speeches[i, :len(speech)] = speech
482
+ speech_masks[i, :vae_tok_length] = True
483
+
484
+ result = {
485
+ "padded_speeches": padded_speeches,
486
+ "speech_masks": speech_masks,
487
+ }
488
+
489
+ # Convert to tensors if requested
490
+ if return_tensors == "pt":
491
+ result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
492
+ result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
493
+
494
+ return result
495
+
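For the default `speech_tok_compress_ratio` of 3200, the padding and mask shapes produced above work out as follows; `processor` is assumed to be a `VibeVoiceProcessor` instance (e.g. from the sketch after `__call__`), and the clip lengths are illustrative:

import numpy as np

clips = [np.zeros(24000, dtype=np.float32),   # 1.0 s at 24 kHz -> ceil(24000 / 3200) = 8 tokens
         np.zeros(36000, dtype=np.float32)]   # 1.5 s at 24 kHz -> ceil(36000 / 3200) = 12 tokens

out = processor.prepare_speech_inputs(clips, return_tensors="pt")
# out["padded_speeches"].shape == (2, 36000)  (zero-padded to the longest clip)
# out["speech_masks"].shape    == (2, 12)     (True for 8 and 12 positions respectively)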
496
+ def _convert_json_to_script(self, json_file: str) -> str:
497
+ """
498
+ Convert JSON format to script format.
499
+ Expected JSON format:
500
+ [
501
+ {"speaker": "1", "text": "Hello everyone..."},
502
+ {"speaker": "2", "text": "Great to be here..."}
503
+ ]
504
+ """
505
+ import json
506
+
507
+ with open(json_file, 'r', encoding='utf-8') as f:
508
+ data = json.load(f)
509
+
510
+ if not isinstance(data, list):
511
+ raise ValueError("JSON file must contain a list of speaker entries")
512
+
513
+ script_lines = []
514
+ for item in data:
515
+ if not isinstance(item, dict):
516
+ logger.warning(f"Skipping non-dict entry: {item}")
517
+ continue
518
+
519
+ speaker = item.get('speaker')
520
+ text = item.get('text')
521
+
522
+ if speaker is None or text is None:
523
+ logger.warning(f"Skipping entry missing speaker or text: {item}")
524
+ continue
525
+
526
+ # Ensure speaker ID is valid
527
+ try:
528
+ speaker_id = int(speaker)
529
+ except (ValueError, TypeError):
530
+ logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
531
+ continue
532
+
533
+ # Clean up text
534
+ text = text.strip()
535
+ if text:
536
+ script_lines.append(f"Speaker {speaker_id}: {text}")
537
+
538
+ if not script_lines:
539
+ raise ValueError("No valid entries found in JSON file")
540
+
541
+ return "\n".join(script_lines)
542
+
543
+ def _convert_text_to_script(self, text_file: str) -> str:
544
+ """
545
+ Convert text file to script format.
546
+ Handles multiple formats:
547
+ 1. Already formatted as "Speaker X: text"
548
+ 2. Plain text (assigns to Speaker 1)
549
+
550
+ Handles edge cases like multiple colons in a line.
551
+ """
552
+ with open(text_file, 'r', encoding='utf-8') as f:
553
+ lines = f.readlines()
554
+
555
+ script_lines = []
556
+ current_speaker = 1
557
+
558
+ for line in lines:
559
+ line = line.strip()
560
+ if not line:
561
+ continue
562
+
563
+ # Try to parse as "Speaker X: text" format
564
+ # Use regex to be more robust
565
+ speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
566
+
567
+ if speaker_match:
568
+ speaker_id = int(speaker_match.group(1))
569
+ text = speaker_match.group(2).strip()
570
+ if text:
571
+ script_lines.append(f"Speaker {speaker_id}: {text}")
572
+ else:
573
+ # Treat as plain text - assign to current speaker
574
+ script_lines.append(f"Speaker {current_speaker}: {line}")
575
+
576
+ if not script_lines:
577
+ raise ValueError("No valid content found in text file")
578
+
579
+ return "\n".join(script_lines)
580
+
581
+ def _parse_script(self, script: str) -> List[Tuple[int, str]]:
582
+ """Parse script into list of (speaker_id, text) tuples."""
583
+ lines = script.strip().split("\n")
584
+ parsed_lines = []
585
+ speaker_ids = []
586
+
587
+ # First pass: parse all lines and collect speaker IDs
588
+ for line in lines:
589
+ if not line.strip():
590
+ continue
591
+
592
+ # Use regex to handle edge cases like multiple colons
593
+ match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
594
+
595
+ if match:
596
+ speaker_id = int(match.group(1))
597
+ text = ' ' + match.group(2).strip()
598
+ parsed_lines.append((speaker_id, text))
599
+ speaker_ids.append(speaker_id)
600
+ else:
601
+ logger.warning(f"Could not parse line: '{line}'")
602
+
603
+ if not parsed_lines:
604
+ raise ValueError("No valid speaker lines found in script")
605
+
606
+ # Check if we need to normalize speaker IDs (only if all are > 0)
607
+ min_speaker_id = min(speaker_ids)
608
+ if min_speaker_id > 0:
609
+ # Normalize to start from 0
610
+ normalized_lines = []
611
+ for speaker_id, text in parsed_lines:
612
+ normalized_lines.append((speaker_id - 1, text))
613
+ return normalized_lines
614
+ else:
615
+ # Keep original IDs
616
+ return parsed_lines
617
+
618
+ def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
619
+ """Merge text and audio inputs into a single BatchEncoding."""
620
+ # Start with text inputs
621
+ merged = BatchEncoding(text_inputs)
622
+
623
+ # Add audio-specific fields
624
+ if "audio" in audio_inputs:
625
+ merged["speech_inputs"] = audio_inputs["audio"]
626
+ if "streaming" in audio_inputs:
627
+ merged["streaming"] = audio_inputs["streaming"]
628
+
629
+ return merged
630
+
631
+ def batch_decode(self, *args, **kwargs):
632
+ """
633
+ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
634
+ Please refer to the docstring of this method for more information.
635
+ """
636
+ return self.tokenizer.batch_decode(*args, **kwargs)
637
+
638
+ def decode(self, *args, **kwargs):
639
+ """
640
+ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
641
+ Please refer to the docstring of this method for more information.
642
+ """
643
+ return self.tokenizer.decode(*args, **kwargs)
644
+
645
+ @property
646
+ def model_input_names(self):
647
+ """
648
+ Return the list of inputs accepted by the model.
649
+ """
650
+ tokenizer_input_names = self.tokenizer.model_input_names
651
+ audio_processor_input_names = self.audio_processor.model_input_names
652
+ return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
653
+
654
+ def save_audio(self,
655
+ audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
656
+ output_path: str = "output.wav",
657
+ sampling_rate: Optional[int] = None,
658
+ normalize: bool = False,
659
+ batch_prefix: str = "audio_",
660
+ ) -> List[str]:
661
+ """
662
+ Save audio data to a file.
663
+ Args:
664
+ audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
665
+ The audio data to save. Can be a single tensor/array or a list of them.
666
+ output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
667
+ sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
668
+ normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
669
+ batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
670
+ Returns:
671
+ List[str]: Paths to the saved audio file(s).
672
+ """
673
+ return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
674
+
675
+ __all__ = [
676
+ "VibeVoiceProcessor",
677
+ ]
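For orientation, a minimal sketch of the transcript formats the conversion helpers above operate on; the dialogue data is invented and the parsing regex simply mirrors `_parse_script`.

import re

transcript = [
    {"speaker": "1", "text": "Hello everyone, welcome to the show."},
    {"speaker": "2", "text": "Great to be here."},
]
# JSON entries become "Speaker N: text" lines, as in _convert_json_to_script.
script = "\n".join(f"Speaker {int(t['speaker'])}: {t['text'].strip()}" for t in transcript)

# _parse_script-style parsing into (speaker_id, text) tuples.
parsed = []
for line in script.split("\n"):
    m = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
    if m:
        parsed.append((int(m.group(1)), ' ' + m.group(2).strip()))
# The smallest speaker ID here is 1 (> 0), so _parse_script would shift the IDs down to 0 and 1.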
processor/vibevoice_tokenizer_processor.py ADDED
@@ -0,0 +1,483 @@
1
+ """
2
+ Processor class for VibeVoice models.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import warnings
8
+ from typing import List, Optional, Union, Dict, Any
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from transformers.feature_extraction_utils import FeatureExtractionMixin
14
+ from transformers.utils import logging
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ class AudioNormalizer:
20
+ """
21
+ Audio normalization class for VibeVoice tokenizer.
22
+
23
+ This class provides audio normalization to ensure consistent input levels
24
+ for the VibeVoice tokenizer while maintaining audio quality.
25
+ """
26
+
27
+ def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
28
+ """
29
+ Initialize the audio normalizer.
30
+
31
+ Args:
32
+ target_dB_FS (float): Target dB FS level for the audio. Default: -25
33
+ eps (float): Small value to avoid division by zero. Default: 1e-6
34
+ """
35
+ self.target_dB_FS = target_dB_FS
36
+ self.eps = eps
37
+
38
+ def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
39
+ """
40
+ Adjust the audio to the target dB FS level.
41
+
42
+ Args:
43
+ audio (np.ndarray): Input audio signal
44
+
45
+ Returns:
46
+ tuple: (normalized_audio, rms, scalar)
47
+ """
48
+ rms = np.sqrt(np.mean(audio**2))
49
+ scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
50
+ normalized_audio = audio * scalar
51
+ return normalized_audio, rms, scalar
52
+
53
+ def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
54
+ """
55
+ Avoid clipping by scaling down if necessary.
56
+
57
+ Args:
58
+ audio (np.ndarray): Input audio signal
59
+ scalar (float, optional): Explicit scaling factor
60
+
61
+ Returns:
62
+ tuple: (normalized_audio, scalar)
63
+ """
64
+ if scalar is None:
65
+ max_val = np.max(np.abs(audio))
66
+ if max_val > 1.0:
67
+ scalar = max_val + self.eps
68
+ else:
69
+ scalar = 1.0
70
+
71
+ return audio / scalar, scalar
72
+
73
+ def __call__(self, audio: np.ndarray) -> np.ndarray:
74
+ """
75
+ Normalize the audio by adjusting to target dB FS and avoiding clipping.
76
+
77
+ Args:
78
+ audio (np.ndarray): Input audio signal
79
+
80
+ Returns:
81
+ np.ndarray: Normalized audio signal
82
+ """
83
+ # First adjust to target dB FS
84
+ audio, _, _ = self.tailor_dB_FS(audio)
85
+ # Then avoid clipping
86
+ audio, _ = self.avoid_clipping(audio)
87
+ return audio
88
+
89
+
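A minimal usage sketch of the normalizer above; the quiet 440 Hz test tone is an invented input.

import numpy as np

normalizer = AudioNormalizer(target_dB_FS=-25)
wav = 0.01 * np.sin(2 * np.pi * 440.0 * np.arange(24000) / 24000.0).astype(np.float32)
out = normalizer(wav)
# RMS of `out` is ~10 ** (-25 / 20) ≈ 0.056, unless the clipping guard had to rescale it.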
90
+ # Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
91
+ class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
92
+ """
93
+ Processor for VibeVoice acoustic tokenizer models.
94
+
95
+ This processor handles audio preprocessing for VibeVoice models, including:
96
+ - Audio format conversion (stereo to mono)
97
+ - Optional audio normalization
98
+ - Streaming support for infinite-length audio
99
+
100
+ Args:
101
+ sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
102
+ normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
103
+ target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
104
+ eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
105
+ """
106
+ model_input_names = ["input_features"]
107
+
108
+ def __init__(
109
+ self,
110
+ sampling_rate: int = 24000,
111
+ normalize_audio: bool = True,
112
+ target_dB_FS: float = -25,
113
+ eps: float = 1e-6,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+
118
+ self.sampling_rate = sampling_rate
119
+ self.normalize_audio = normalize_audio
120
+
121
+ # Initialize audio normalizer if needed
122
+ if self.normalize_audio:
123
+ self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
124
+ else:
125
+ self.normalizer = None
126
+
127
+ # Save config
128
+ self.feature_extractor_dict = {
129
+ "sampling_rate": sampling_rate,
130
+ "normalize_audio": normalize_audio,
131
+ "target_dB_FS": target_dB_FS,
132
+ "eps": eps,
133
+ }
134
+
135
+ def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
136
+ """
137
+ Convert stereo audio to mono if needed.
138
+
139
+ Args:
140
+ audio (np.ndarray): Input audio array
141
+
142
+ Returns:
143
+ np.ndarray: Mono audio array
144
+ """
145
+ if len(audio.shape) == 1:
146
+ return audio
147
+ elif len(audio.shape) == 2:
148
+ if audio.shape[0] == 2: # (2, time)
149
+ return np.mean(audio, axis=0)
150
+ elif audio.shape[1] == 2: # (time, 2)
151
+ return np.mean(audio, axis=1)
152
+ else:
153
+ # If one dimension is 1, squeeze it
154
+ if audio.shape[0] == 1:
155
+ return audio.squeeze(0)
156
+ elif audio.shape[1] == 1:
157
+ return audio.squeeze(1)
158
+ else:
159
+ raise ValueError(f"Unexpected audio shape: {audio.shape}")
160
+ else:
161
+ raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
162
+
163
+ def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
164
+ """
165
+ Process a single audio array.
166
+
167
+ Args:
168
+ audio: Single audio input
169
+
170
+ Returns:
171
+ np.ndarray: Processed audio
172
+ """
173
+ # Convert to numpy array
174
+ if not isinstance(audio, np.ndarray):
175
+ audio = np.array(audio, dtype=np.float32)
176
+ else:
177
+ audio = audio.astype(np.float32)
178
+
179
+ # Ensure mono
180
+ audio = self._ensure_mono(audio)
181
+
182
+ # Normalize if requested
183
+ if self.normalize_audio and self.normalizer is not None:
184
+ audio = self.normalizer(audio)
185
+
186
+ return audio
187
+
188
+ def __call__(
189
+ self,
190
+ audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
191
+ sampling_rate: Optional[int] = None,
192
+ return_tensors: Optional[str] = None,
193
+ **kwargs,
194
+ ):
195
+ """
196
+ Process audio for VibeVoice models.
197
+
198
+ Args:
199
+ audio: Audio input(s) to process. Can be:
200
+ - str: Path to audio file
201
+ - np.ndarray: Audio array
202
+ - List[float]: Audio as list of floats
203
+ - List[np.ndarray]: Batch of audio arrays
204
+ - List[str]: Batch of audio file paths
205
+ sampling_rate (int, optional): Sampling rate of the input audio
206
+ return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
207
+
208
+ Returns:
209
+ dict: Processed audio inputs with keys:
210
+ - input_features: Audio tensor(s) ready for the model
211
+ """
212
+ if audio is None:
213
+ raise ValueError("Audio input is required")
214
+
215
+ # Validate sampling rate
216
+ if sampling_rate is not None and sampling_rate != self.sampling_rate:
217
+ logger.warning(
218
+ f"Input sampling rate ({sampling_rate}) differs from expected "
219
+ f"sampling rate ({self.sampling_rate}). Please resample your audio."
220
+ )
221
+
222
+ # Handle different input types
223
+ if isinstance(audio, str):
224
+ # Single audio file path
225
+ audio = self._load_audio_from_path(audio)
226
+ is_batched = False
227
+ elif isinstance(audio, list):
228
+ if len(audio) == 0:
229
+ raise ValueError("Empty audio list provided")
230
+
231
+ # Check if it's a list of file paths
232
+ if all(isinstance(item, str) for item in audio):
233
+ # Batch of audio file paths
234
+ audio = [self._load_audio_from_path(path) for path in audio]
235
+ is_batched = True
236
+ else:
237
+ # Check if it's batched audio arrays
238
+ is_batched = isinstance(audio[0], (np.ndarray, list))
239
+ else:
240
+ # Single audio array or list
241
+ is_batched = False
242
+
243
+ # Process audio
244
+ if is_batched:
245
+ processed_audio = [self._process_single_audio(a) for a in audio]
246
+ else:
247
+ processed_audio = [self._process_single_audio(audio)]
248
+
249
+ # Convert to tensors if requested
250
+ if return_tensors == "pt":
251
+ if len(processed_audio) == 1:
252
+ # Add batch and channel dimensions: (T,) -> (1, 1, T)
253
+ input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
254
+ else:
255
+ # For batched input, stack into (B, 1, T); this assumes all clips share the same length
256
+ input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
257
+ elif return_tensors == "np":
258
+ if len(processed_audio) == 1:
259
+ input_features = processed_audio[0][np.newaxis, np.newaxis, :]
260
+ else:
261
+ input_features = np.stack(processed_audio)[:, np.newaxis, :]
262
+ else:
263
+ input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
264
+
265
+ outputs = {
266
+ "audio": input_features, # Use "audio" instead of "input_features"
267
+ }
268
+
269
+ return outputs
270
+
271
+ def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
272
+ """
273
+ Load audio from file path.
274
+
275
+ Args:
276
+ audio_path (str): Path to audio file
277
+
278
+ Returns:
279
+ np.ndarray: Loaded audio array
280
+ """
281
+ # Get file extension to determine loading method
282
+ file_ext = os.path.splitext(audio_path)[1].lower()
283
+
284
+ if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
285
+ # Audio file - use librosa
286
+ import librosa
287
+ audio_array, sr = librosa.load(
288
+ audio_path,
289
+ sr=self.sampling_rate,
290
+ mono=True
291
+ )
292
+ return audio_array
293
+ elif file_ext == '.pt':
294
+ # PyTorch tensor file
295
+ audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
296
+ if isinstance(audio_tensor, torch.Tensor):
297
+ audio_array = audio_tensor.numpy()
298
+ else:
299
+ audio_array = np.array(audio_tensor)
300
+ return audio_array.astype(np.float32)
301
+ elif file_ext == '.npy':
302
+ # NumPy file
303
+ audio_array = np.load(audio_path)
304
+ return audio_array.astype(np.float32)
305
+ else:
306
+ raise ValueError(
307
+ f"Unsupported file format: {file_ext}. "
308
+ f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
309
+ )
310
+
311
+ def preprocess_audio(
312
+ self,
313
+ audio_path_or_array: Union[str, np.ndarray],
314
+ normalize: Optional[bool] = None,
315
+ ) -> np.ndarray:
316
+ """
317
+ Convenience method to preprocess audio from file path or array.
318
+ This method is kept for backward compatibility but __call__ is recommended.
319
+
320
+ Args:
321
+ audio_path_or_array: Path to audio file or numpy array
322
+ normalize: Whether to normalize (overrides default setting)
323
+
324
+ Returns:
325
+ np.ndarray: Preprocessed audio array
326
+ """
327
+ if isinstance(audio_path_or_array, str):
328
+ audio_array = self._load_audio_from_path(audio_path_or_array)
329
+ else:
330
+ audio_array = np.array(audio_path_or_array, dtype=np.float32)
331
+
332
+ # Override normalization setting if specified
333
+ original_normalize = self.normalize_audio
334
+ if normalize is not None:
335
+ self.normalize_audio = normalize
336
+
337
+ try:
338
+ processed = self._process_single_audio(audio_array)
339
+ finally:
340
+ # Restore original setting
341
+ self.normalize_audio = original_normalize
342
+
343
+ return processed
344
+
345
+ # Override to_dict method for configuration saving
346
+ def to_dict(self) -> Dict[str, Any]:
347
+ """
348
+ Convert the object to a dict containing all attributes needed for serialization.
349
+ """
350
+ return self.feature_extractor_dict
351
+
352
+ def save_audio(
353
+ self,
354
+ audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
355
+ output_path: str = "output.wav",
356
+ sampling_rate: Optional[int] = None,
357
+ normalize: bool = False,
358
+ batch_prefix: str = "audio_",
359
+ ):
360
+ """
361
+ Save audio data to WAV file(s).
362
+
363
+ Args:
364
+ audio: Audio data to save. Can be:
365
+ - torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
366
+ - np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
367
+ - List of tensors or arrays
368
+ output_path: Path where to save the audio. If saving multiple files,
369
+ this is treated as a directory and individual files will be saved inside.
370
+ sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
371
+ normalize: Whether to normalize audio before saving.
372
+ batch_prefix: Prefix for batch files when saving multiple audios.
373
+
374
+ Returns:
375
+ List[str]: Paths to the saved audio files.
376
+ """
377
+ if sampling_rate is None:
378
+ sampling_rate = self.sampling_rate
379
+
380
+ try:
381
+ import soundfile as sf
382
+ except ImportError:
383
+ raise ImportError(
384
+ "soundfile is required to save audio files. "
385
+ "Install it with: pip install soundfile"
386
+ )
387
+
388
+ # Ensure audio is in the right format
389
+ if isinstance(audio, torch.Tensor):
390
+ # Convert PyTorch tensor to numpy
391
+ audio_np = audio.float().detach().cpu().numpy()
392
+ elif isinstance(audio, np.ndarray):
393
+ audio_np = audio
394
+ elif isinstance(audio, list):
395
+ # Handle list of tensors or arrays
396
+ if all(isinstance(a, torch.Tensor) for a in audio):
397
+ audio_np = [a.float().detach().cpu().numpy() for a in audio]
398
+ else:
399
+ audio_np = audio
400
+ else:
401
+ raise ValueError(f"Unsupported audio type: {type(audio)}")
402
+
403
+ saved_paths = []
404
+
405
+ # Handle based on shape or type
406
+ if isinstance(audio_np, list):
407
+ # Multiple separate audios to save
408
+ output_dir = output_path
409
+
410
+ # Ensure output directory exists
411
+ os.makedirs(output_dir, exist_ok=True)
412
+
413
+ # Save each audio
414
+ for i, audio_item in enumerate(audio_np):
415
+ audio_item = self._prepare_audio_for_save(audio_item, normalize)
416
+ file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
417
+ sf.write(file_path, audio_item, sampling_rate)
418
+ saved_paths.append(file_path)
419
+
420
+ else:
421
+ # Handle different dimensions
422
+ if len(audio_np.shape) >= 3: # (B, C, T) or similar
423
+ # Get batch size
424
+ batch_size = audio_np.shape[0]
425
+
426
+ if batch_size > 1:
427
+ # Multiple audios in a batch
428
+ output_dir = output_path
429
+
430
+ # Ensure output directory exists
431
+ os.makedirs(output_dir, exist_ok=True)
432
+
433
+ # Save each audio in the batch
434
+ for i in range(batch_size):
435
+ # Extract single audio and remove channel dim if present
436
+ single_audio = audio_np[i]
437
+ if len(single_audio.shape) > 1:
438
+ if single_audio.shape[0] == 1: # (1, T)
439
+ single_audio = single_audio.squeeze(0)
440
+
441
+ single_audio = self._prepare_audio_for_save(single_audio, normalize)
442
+ file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
443
+ sf.write(file_path, single_audio, sampling_rate)
444
+ saved_paths.append(file_path)
445
+ else:
446
+ # Single audio with batch and channel dims
447
+ audio_item = audio_np.squeeze() # Remove batch and channel dimensions
448
+ audio_item = self._prepare_audio_for_save(audio_item, normalize)
449
+ sf.write(output_path, audio_item, sampling_rate)
450
+ saved_paths.append(output_path)
451
+ else:
452
+ # Single audio without batch dimension
453
+ audio_item = self._prepare_audio_for_save(audio_np, normalize)
454
+ sf.write(output_path, audio_item, sampling_rate)
455
+ saved_paths.append(output_path)
456
+
457
+ return saved_paths
458
+
459
+ def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
460
+ """
461
+ Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
462
+
463
+ Args:
464
+ audio: Audio data as numpy array
465
+ normalize: Whether to normalize audio
466
+
467
+ Returns:
468
+ np.ndarray: Processed audio ready for saving
469
+ """
470
+ # Ensure right dimensionality
471
+ if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
472
+ audio = audio.squeeze(0)
473
+
474
+ # Normalize if requested
475
+ if normalize:
476
+ max_val = np.abs(audio).max()
477
+ if max_val > 0:
478
+ audio = audio / max_val
479
+
480
+ return audio
481
+
482
+
483
+ __all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
schedule/__init__.py ADDED
File without changes
schedule/dpm_solver.py ADDED
@@ -0,0 +1,1065 @@
1
+ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import deprecate
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+ def betas_for_alpha_bar(
29
+ num_diffusion_timesteps,
30
+ max_beta=0.999,
31
+ alpha_transform_type="cosine",
32
+ ):
33
+ """
34
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
35
+ (1-beta) over time from t = [0,1].
36
+
37
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
38
+ to that part of the diffusion process.
39
+
40
+
41
+ Args:
42
+ num_diffusion_timesteps (`int`): the number of betas to produce.
43
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
44
+ prevent singularities.
45
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
46
+ Choose from `cosine` or `exp`
47
+
48
+ Returns:
49
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
50
+ """
51
+ if alpha_transform_type == "cosine":
52
+
53
+ def alpha_bar_fn(t):
54
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
55
+ # return math.cos(t * math.pi / 2 * 0.95) ** 2
56
+
57
+ elif alpha_transform_type == "exp":
58
+
59
+ def alpha_bar_fn(t):
60
+ return math.exp(t * -12.0)
61
+
62
+ elif alpha_transform_type == "cauchy":
63
+ # µ + γ tan (π (0.5 - x)) γ = 1, µ = 3
64
+ # alpha^2 = 1-1/(exp(λ)+1)
65
+ def alpha_bar_fn(t, gamma=1, mu=3):
66
+ snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
67
+ return 1 - 1 / (math.exp(snr) + 1.1)
68
+
69
+ elif alpha_transform_type == "laplace":
70
+ # µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1
71
+ def alpha_bar_fn(t, mu=0, b=1):
72
+ snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
73
+ return 1 - 1 / (math.exp(snr) + 1.02)
74
+
75
+ else:
76
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
77
+
78
+ betas = []
79
+ for i in range(num_diffusion_timesteps):
80
+ t1 = i / num_diffusion_timesteps
81
+ t2 = (i + 1) / num_diffusion_timesteps
82
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
83
+ return torch.tensor(betas, dtype=torch.float32)
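A small sanity-check sketch of the helper above, assuming it is importable alongside torch; the step count of 1000 is only an example.

import torch

betas = betas_for_alpha_bar(1000, alpha_transform_type="cosine")
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
# 1000 betas in (0, 0.999]; the cumulative product of (1 - beta) decays from ~1 toward 0.
print(betas.shape, float(alphas_cumprod[0]), float(alphas_cumprod[-1]))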
84
+
85
+
86
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
87
+ def rescale_zero_terminal_snr(betas):
88
+ """
89
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
90
+
91
+
92
+ Args:
93
+ betas (`torch.Tensor`):
94
+ the betas that the scheduler is being initialized with.
95
+
96
+ Returns:
97
+ `torch.Tensor`: rescaled betas with zero terminal SNR
98
+ """
99
+ # Convert betas to alphas_bar_sqrt
100
+ alphas = 1.0 - betas
101
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
102
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
103
+
104
+ # Store old values.
105
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
106
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
107
+
108
+ # Shift so the last timestep is zero.
109
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
110
+
111
+ # Scale so the first timestep is back to the old value.
112
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
113
+
114
+ # Convert alphas_bar_sqrt to betas
115
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
116
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
117
+ alphas = torch.cat([alphas_bar[0:1], alphas])
118
+ betas = 1 - alphas
119
+
120
+ return betas
121
+
122
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
123
+ """
124
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
125
+
126
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
127
+ methods the library implements for all schedulers such as loading and saving.
128
+
129
+ Args:
130
+ num_train_timesteps (`int`, defaults to 1000):
131
+ The number of diffusion steps to train the model.
132
+ beta_start (`float`, defaults to 0.0001):
133
+ The starting `beta` value of inference.
134
+ beta_end (`float`, defaults to 0.02):
135
+ The final `beta` value.
136
+ beta_schedule (`str`, defaults to `"linear"`):
137
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
138
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
139
+ trained_betas (`np.ndarray`, *optional*):
140
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
141
+ solver_order (`int`, defaults to 2):
142
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
143
+ sampling, and `solver_order=3` for unconditional sampling.
144
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
145
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
146
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
147
+ Video](https://imagen.research.google/video/paper.pdf) paper).
148
+ thresholding (`bool`, defaults to `False`):
149
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
150
+ as Stable Diffusion.
151
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
152
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
153
+ sample_max_value (`float`, defaults to 1.0):
154
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
155
+ `algorithm_type="dpmsolver++"`.
156
+ algorithm_type (`str`, defaults to `dpmsolver++`):
157
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
158
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
159
+ paper, and the `dpmsolver++` type implements the algorithms in the
160
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
161
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
162
+ solver_type (`str`, defaults to `midpoint`):
163
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
164
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
165
+ lower_order_final (`bool`, defaults to `True`):
166
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
167
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
168
+ euler_at_final (`bool`, defaults to `False`):
169
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
170
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
171
+ steps, but sometimes may result in blurring.
172
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
173
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
174
+ the sigmas are determined according to a sequence of noise levels {σi}.
175
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
176
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
177
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
178
+ `lambda(t)`.
179
+ final_sigmas_type (`str`, defaults to `"zero"`):
180
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
181
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
182
+ lambda_min_clipped (`float`, defaults to `-inf`):
183
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
184
+ cosine (`squaredcos_cap_v2`) noise schedule.
185
+ variance_type (`str`, *optional*):
186
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
187
+ contains the predicted Gaussian variance.
188
+ timestep_spacing (`str`, defaults to `"linspace"`):
189
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
190
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
191
+ steps_offset (`int`, defaults to 0):
192
+ An offset added to the inference steps, as required by some model families.
193
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
194
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
195
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
196
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
197
+ """
198
+
199
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
200
+ order = 1
201
+
202
+ @register_to_config
203
+ def __init__(
204
+ self,
205
+ num_train_timesteps: int = 1000,
206
+ beta_start: float = 0.0001,
207
+ beta_end: float = 0.02,
208
+ beta_schedule: str = "linear",
209
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
210
+ solver_order: int = 2,
211
+ prediction_type: str = "epsilon",
212
+ thresholding: bool = False,
213
+ dynamic_thresholding_ratio: float = 0.995,
214
+ sample_max_value: float = 1.0,
215
+ algorithm_type: str = "dpmsolver++",
216
+ solver_type: str = "midpoint",
217
+ lower_order_final: bool = True,
218
+ euler_at_final: bool = False,
219
+ use_karras_sigmas: Optional[bool] = False,
220
+ use_lu_lambdas: Optional[bool] = False,
221
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
222
+ lambda_min_clipped: float = -float("inf"),
223
+ variance_type: Optional[str] = None,
224
+ timestep_spacing: str = "linspace",
225
+ steps_offset: int = 0,
226
+ rescale_betas_zero_snr: bool = False,
227
+ ):
228
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
229
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
230
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
231
+
232
+ if trained_betas is not None:
233
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
234
+ elif beta_schedule == "linear":
235
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
236
+ elif beta_schedule == "scaled_linear":
237
+ # this schedule is very specific to the latent diffusion model.
238
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
239
+ elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
240
+ # Glide cosine schedule
241
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
242
+ elif beta_schedule == "cauchy":
243
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
244
+ elif beta_schedule == "laplace":
245
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
246
+ else:
247
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
248
+
249
+ if rescale_betas_zero_snr:
250
+ self.betas = rescale_zero_terminal_snr(self.betas)
251
+
252
+ self.alphas = 1.0 - self.betas
253
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
254
+
255
+ if rescale_betas_zero_snr:
256
+ # Close to 0 without being 0 so first sigma is not inf
257
+ # FP16 smallest positive subnormal works well here
258
+ self.alphas_cumprod[-1] = 2**-24
259
+
260
+ # Currently we only support VP-type noise schedule
261
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
262
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
263
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
264
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
265
+
266
+ # standard deviation of the initial noise distribution
267
+ self.init_noise_sigma = 1.0
268
+
269
+ # settings for DPM-Solver
270
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
271
+ if algorithm_type == "deis":
272
+ self.register_to_config(algorithm_type="dpmsolver++")
273
+ else:
274
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
275
+
276
+ if solver_type not in ["midpoint", "heun"]:
277
+ if solver_type in ["logrho", "bh1", "bh2"]:
278
+ self.register_to_config(solver_type="midpoint")
279
+ else:
280
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
281
+
282
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
283
+ raise ValueError(
284
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
285
+ )
286
+
287
+ # setable values
288
+ self.num_inference_steps = None
289
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
290
+ self.timesteps = torch.from_numpy(timesteps)
291
+ self.model_outputs = [None] * solver_order
292
+ self.lower_order_nums = 0
293
+ self._step_index = None
294
+ self._begin_index = None
295
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
296
+
297
+ @property
298
+ def step_index(self):
299
+ """
300
+ The index counter for current timestep. It will increase 1 after each scheduler step.
301
+ """
302
+ return self._step_index
303
+
304
+ @property
305
+ def begin_index(self):
306
+ """
307
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
308
+ """
309
+ return self._begin_index
310
+
311
+ def set_begin_index(self, begin_index: int = 0):
312
+ """
313
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
314
+
315
+ Args:
316
+ begin_index (`int`):
317
+ The begin index for the scheduler.
318
+ """
319
+ self._begin_index = begin_index
320
+
321
+ def set_timesteps(
322
+ self,
323
+ num_inference_steps: int = None,
324
+ device: Union[str, torch.device] = None,
325
+ timesteps: Optional[List[int]] = None,
326
+ ):
327
+ """
328
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
329
+
330
+ Args:
331
+ num_inference_steps (`int`):
332
+ The number of diffusion steps used when generating samples with a pre-trained model.
333
+ device (`str` or `torch.device`, *optional*):
334
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
335
+ timesteps (`List[int]`, *optional*):
336
+ Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
337
+ based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
338
+ must be `None`, and `timestep_spacing` attribute will be ignored.
339
+ """
340
+ if num_inference_steps is None and timesteps is None:
341
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
342
+ if num_inference_steps is not None and timesteps is not None:
343
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
344
+ if timesteps is not None and self.config.use_karras_sigmas:
345
+ raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
346
+ if timesteps is not None and self.config.use_lu_lambdas:
347
+ raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
348
+
349
+ if timesteps is not None:
350
+ timesteps = np.array(timesteps).astype(np.int64)
351
+ else:
352
+ # Clipping the minimum of all lambda(t) for numerical stability.
353
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
354
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
355
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
356
+
357
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
358
+ if self.config.timestep_spacing == "linspace":
359
+ timesteps = (
360
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1)
361
+ .round()[::-1][:-1]
362
+ .copy()
363
+ .astype(np.int64)
364
+ )
365
+ elif self.config.timestep_spacing == "leading":
366
+ step_ratio = last_timestep // (num_inference_steps + 1)
367
+ # creates integer timesteps by multiplying by ratio
368
+ # casting to int to avoid issues when num_inference_step is power of 3
369
+ timesteps = (
370
+ (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
371
+ )
372
+ timesteps += self.config.steps_offset
373
+ elif self.config.timestep_spacing == "trailing":
374
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
375
+ # creates integer timesteps by multiplying by ratio
376
+ # casting to int to avoid issues when num_inference_step is power of 3
377
+ timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
378
+ timesteps -= 1
379
+ else:
380
+ raise ValueError(
381
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
382
+ )
383
+
384
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
385
+ log_sigmas = np.log(sigmas)
386
+
387
+ if self.config.use_karras_sigmas:
388
+ sigmas = np.flip(sigmas).copy()
389
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
390
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
391
+ elif self.config.use_lu_lambdas:
392
+ lambdas = np.flip(log_sigmas.copy())
393
+ lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
394
+ sigmas = np.exp(lambdas)
395
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
396
+ else:
397
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
398
+
399
+ if self.config.final_sigmas_type == "sigma_min":
400
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
401
+ elif self.config.final_sigmas_type == "zero":
402
+ sigma_last = 0
403
+ else:
404
+ raise ValueError(
405
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
406
+ )
407
+
408
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
409
+
410
+ self.sigmas = torch.from_numpy(sigmas)
411
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
412
+
413
+ self.num_inference_steps = len(timesteps)
414
+
415
+ self.model_outputs = [
416
+ None,
417
+ ] * self.config.solver_order
418
+ self.lower_order_nums = 0
419
+
420
+ # add an index counter for schedulers that allow duplicated timesteps
421
+ self._step_index = None
422
+ self._begin_index = None
423
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
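A usage sketch of the scheduler defined here; the settings below are illustrative choices, not values read from any particular checkpoint or config.

scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="cosine",
    prediction_type="v_prediction",
    algorithm_type="dpmsolver++",
)
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.timesteps.shape)  # 20 descending integer timesteps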
424
+
425
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
426
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
427
+ """
428
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
429
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
430
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
431
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
432
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
433
+
434
+ https://arxiv.org/abs/2205.11487
435
+ """
436
+ dtype = sample.dtype
437
+ batch_size, channels, *remaining_dims = sample.shape
438
+
439
+ if dtype not in (torch.float32, torch.float64):
440
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
441
+
442
+ # Flatten sample for doing quantile calculation along each image
443
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
444
+
445
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
446
+
447
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
448
+ s = torch.clamp(
449
+ s, min=1, max=self.config.sample_max_value
450
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
451
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
452
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
453
+
454
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
455
+ sample = sample.to(dtype)
456
+
457
+ return sample
458
+
459
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
460
+ def _sigma_to_t(self, sigma, log_sigmas):
461
+ # get log sigma
462
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
463
+
464
+ # get distribution
465
+ dists = log_sigma - log_sigmas[:, np.newaxis]
466
+
467
+ # get sigmas range
468
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
469
+ high_idx = low_idx + 1
470
+
471
+ low = log_sigmas[low_idx]
472
+ high = log_sigmas[high_idx]
473
+
474
+ # interpolate sigmas
475
+ w = (low - log_sigma) / (low - high)
476
+ w = np.clip(w, 0, 1)
477
+
478
+ # transform interpolation to time range
479
+ t = (1 - w) * low_idx + w * high_idx
480
+ t = t.reshape(sigma.shape)
481
+ return t
482
+
483
+ def _sigma_to_alpha_sigma_t(self, sigma):
484
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
485
+ sigma_t = sigma * alpha_t
486
+
487
+ return alpha_t, sigma_t
488
+
489
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
490
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
491
+ """Constructs the noise schedule of Karras et al. (2022)."""
492
+
493
+ # Hack to make sure that other schedulers which copy this function don't break
494
+ # TODO: Add this logic to the other schedulers
495
+ if hasattr(self.config, "sigma_min"):
496
+ sigma_min = self.config.sigma_min
497
+ else:
498
+ sigma_min = None
499
+
500
+ if hasattr(self.config, "sigma_max"):
501
+ sigma_max = self.config.sigma_max
502
+ else:
503
+ sigma_max = None
504
+
505
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
506
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
507
+
508
+ rho = 7.0 # 7.0 is the value used in the paper
509
+ ramp = np.linspace(0, 1, num_inference_steps)
510
+ min_inv_rho = sigma_min ** (1 / rho)
511
+ max_inv_rho = sigma_max ** (1 / rho)
512
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
513
+ return sigmas
514
+
515
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
516
+ """Constructs the noise schedule of Lu et al. (2022)."""
517
+
518
+ lambda_min: float = in_lambdas[-1].item()
519
+ lambda_max: float = in_lambdas[0].item()
520
+
521
+ rho = 1.0 # 1.0 is the value used in the paper
522
+ ramp = np.linspace(0, 1, num_inference_steps)
523
+ min_inv_rho = lambda_min ** (1 / rho)
524
+ max_inv_rho = lambda_max ** (1 / rho)
525
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
526
+ return lambdas
527
+
528
+ def convert_model_output(
529
+ self,
530
+ model_output: torch.Tensor,
531
+ *args,
532
+ sample: torch.Tensor = None,
533
+ **kwargs,
534
+ ) -> torch.Tensor:
535
+ """
536
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
537
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
538
+ integral of the data prediction model.
539
+
540
+ <Tip>
541
+
542
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
543
+ prediction and data prediction models.
544
+
545
+ </Tip>
546
+
547
+ Args:
548
+ model_output (`torch.Tensor`):
549
+ The direct output from the learned diffusion model.
550
+ sample (`torch.Tensor`):
551
+ A current instance of a sample created by the diffusion process.
552
+
553
+ Returns:
554
+ `torch.Tensor`:
555
+ The converted model output.
556
+ """
557
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
558
+ if sample is None:
559
+ if len(args) > 1:
560
+ sample = args[1]
561
+ else:
562
+ raise ValueError("missing `sample` as a required keyward argument")
563
+ if timestep is not None:
564
+ deprecate(
565
+ "timesteps",
566
+ "1.0.0",
567
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
568
+ )
569
+
570
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
571
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
572
+ if self.config.prediction_type == "epsilon":
573
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
574
+ if self.config.variance_type in ["learned", "learned_range"]:
575
+ model_output = model_output[:, :3]
576
+ sigma = self.sigmas[self.step_index]
577
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
578
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
579
+ elif self.config.prediction_type == "sample":
580
+ x0_pred = model_output
581
+ elif self.config.prediction_type == "v_prediction":
582
+ sigma = self.sigmas[self.step_index]
583
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
584
+ x0_pred = alpha_t * sample - sigma_t * model_output
585
+ else:
586
+ raise ValueError(
587
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
588
+ " `v_prediction` for the DPMSolverMultistepScheduler."
589
+ )
590
+
591
+ if self.config.thresholding:
592
+ x0_pred = self._threshold_sample(x0_pred)
593
+
594
+ return x0_pred
595
+
596
+ # DPM-Solver needs to solve an integral of the noise prediction model.
597
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
598
+ if self.config.prediction_type == "epsilon":
599
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
600
+ if self.config.variance_type in ["learned", "learned_range"]:
601
+ epsilon = model_output[:, :3]
602
+ else:
603
+ epsilon = model_output
604
+ elif self.config.prediction_type == "sample":
605
+ sigma = self.sigmas[self.step_index]
606
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
607
+ epsilon = (sample - alpha_t * model_output) / sigma_t
608
+ elif self.config.prediction_type == "v_prediction":
609
+ sigma = self.sigmas[self.step_index]
610
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
611
+ epsilon = alpha_t * model_output + sigma_t * sample
612
+ else:
613
+ raise ValueError(
614
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
615
+ " `v_prediction` for the DPMSolverMultistepScheduler."
616
+ )
617
+
618
+ if self.config.thresholding:
619
+ sigma = self.sigmas[self.step_index]
620
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
621
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
622
+ x0_pred = self._threshold_sample(x0_pred)
623
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
624
+
625
+ return epsilon
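A small numerical check of the v-prediction bookkeeping used in `convert_model_output`, with made-up scalars.

import torch

sigma = torch.tensor(2.0)                        # made-up noise level
alpha_t = 1.0 / (sigma**2 + 1.0).sqrt()          # as in _sigma_to_alpha_sigma_t
sigma_t = sigma * alpha_t
x0, eps = torch.tensor(0.3), torch.tensor(-1.2)  # made-up clean sample and noise
x_t = alpha_t * x0 + sigma_t * eps
v = alpha_t * eps - sigma_t * x0                 # the standard v-prediction target
assert torch.allclose(alpha_t * x_t - sigma_t * v, x0)   # x0 recovery (dpmsolver++ branch)
assert torch.allclose(alpha_t * v + sigma_t * x_t, eps)  # epsilon recovery (dpmsolver branch)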
626
+
627
+ def dpm_solver_first_order_update(
628
+ self,
629
+ model_output: torch.Tensor,
630
+ *args,
631
+ sample: torch.Tensor = None,
632
+ noise: Optional[torch.Tensor] = None,
633
+ **kwargs,
634
+ ) -> torch.Tensor:
635
+ """
636
+ One step for the first-order DPMSolver (equivalent to DDIM).
637
+
638
+ Args:
639
+ model_output (`torch.Tensor`):
640
+ The direct output from the learned diffusion model.
641
+ sample (`torch.Tensor`):
642
+ A current instance of a sample created by the diffusion process.
643
+
644
+ Returns:
645
+ `torch.Tensor`:
646
+ The sample tensor at the previous timestep.
647
+ """
648
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
649
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
650
+ if sample is None:
651
+ if len(args) > 2:
652
+ sample = args[2]
653
+ else:
654
+ raise ValueError(" missing `sample` as a required keyward argument")
655
+ if timestep is not None:
656
+ deprecate(
657
+ "timesteps",
658
+ "1.0.0",
659
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
660
+ )
661
+
662
+ if prev_timestep is not None:
663
+ deprecate(
664
+ "prev_timestep",
665
+ "1.0.0",
666
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
667
+ )
668
+
669
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
670
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
671
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
672
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
673
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
674
+
675
+ h = lambda_t - lambda_s
676
+ if self.config.algorithm_type == "dpmsolver++":
677
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
678
+ elif self.config.algorithm_type == "dpmsolver":
679
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
680
+ elif self.config.algorithm_type == "sde-dpmsolver++":
681
+ assert noise is not None
682
+ x_t = (
683
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
684
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
685
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
686
+ )
687
+ elif self.config.algorithm_type == "sde-dpmsolver":
688
+ assert noise is not None
689
+ x_t = (
690
+ (alpha_t / alpha_s) * sample
691
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
692
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
693
+ )
694
+ return x_t
695
+
696
+ def multistep_dpm_solver_second_order_update(
697
+ self,
698
+ model_output_list: List[torch.Tensor],
699
+ *args,
700
+ sample: torch.Tensor = None,
701
+ noise: Optional[torch.Tensor] = None,
702
+ **kwargs,
703
+ ) -> torch.Tensor:
704
+ """
705
+ One step for the second-order multistep DPMSolver.
706
+
707
+ Args:
708
+ model_output_list (`List[torch.Tensor]`):
709
+ The direct outputs from learned diffusion model at current and latter timesteps.
710
+ sample (`torch.Tensor`):
711
+ A current instance of a sample created by the diffusion process.
712
+
713
+ Returns:
714
+ `torch.Tensor`:
715
+ The sample tensor at the previous timestep.
716
+ """
717
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
718
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
719
+ if sample is None:
720
+ if len(args) > 2:
721
+ sample = args[2]
722
+ else:
723
+ raise ValueError("missing `sample` as a required keyword argument")
724
+ if timestep_list is not None:
725
+ deprecate(
726
+ "timestep_list",
727
+ "1.0.0",
728
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
729
+ )
730
+
731
+ if prev_timestep is not None:
732
+ deprecate(
733
+ "prev_timestep",
734
+ "1.0.0",
735
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
736
+ )
737
+
738
+ sigma_t, sigma_s0, sigma_s1 = (
739
+ self.sigmas[self.step_index + 1],
740
+ self.sigmas[self.step_index],
741
+ self.sigmas[self.step_index - 1],
742
+ )
743
+
744
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
745
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
746
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
747
+
748
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
749
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
750
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
751
+
752
+ m0, m1 = model_output_list[-1], model_output_list[-2]
753
+
754
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
755
+ r0 = h_0 / h
756
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
757
+ if self.config.algorithm_type == "dpmsolver++":
758
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
759
+ if self.config.solver_type == "midpoint":
760
+ x_t = (
761
+ (sigma_t / sigma_s0) * sample
762
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
763
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
764
+ )
765
+ elif self.config.solver_type == "heun":
766
+ x_t = (
767
+ (sigma_t / sigma_s0) * sample
768
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
769
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
770
+ )
771
+ elif self.config.algorithm_type == "dpmsolver":
772
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
773
+ if self.config.solver_type == "midpoint":
774
+ x_t = (
775
+ (alpha_t / alpha_s0) * sample
776
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
777
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
778
+ )
779
+ elif self.config.solver_type == "heun":
780
+ x_t = (
781
+ (alpha_t / alpha_s0) * sample
782
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
783
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
784
+ )
785
+ elif self.config.algorithm_type == "sde-dpmsolver++":
786
+ assert noise is not None
787
+ if self.config.solver_type == "midpoint":
788
+ x_t = (
789
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
790
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
791
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
792
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
793
+ )
794
+ elif self.config.solver_type == "heun":
795
+ x_t = (
796
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
797
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
798
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
799
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
800
+ )
801
+ elif self.config.algorithm_type == "sde-dpmsolver":
802
+ assert noise is not None
803
+ if self.config.solver_type == "midpoint":
804
+ x_t = (
805
+ (alpha_t / alpha_s0) * sample
806
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
807
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
808
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
809
+ )
810
+ elif self.config.solver_type == "heun":
811
+ x_t = (
812
+ (alpha_t / alpha_s0) * sample
813
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
814
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
815
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
816
+ )
817
+ return x_t
818
+
819
+ def multistep_dpm_solver_third_order_update(
820
+ self,
821
+ model_output_list: List[torch.Tensor],
822
+ *args,
823
+ sample: torch.Tensor = None,
824
+ **kwargs,
825
+ ) -> torch.Tensor:
826
+ """
827
+ One step for the third-order multistep DPMSolver.
828
+
829
+ Args:
830
+ model_output_list (`List[torch.Tensor]`):
831
+ The direct outputs from the learned diffusion model at the current and earlier timesteps.
832
+ sample (`torch.Tensor`):
833
+ A current instance of a sample created by the diffusion process.
834
+
835
+ Returns:
836
+ `torch.Tensor`:
837
+ The sample tensor at the previous timestep.
838
+ """
839
+
840
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
841
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
842
+ if sample is None:
843
+ if len(args) > 2:
844
+ sample = args[2]
845
+ else:
846
+ raise ValueError("missing `sample` as a required keyword argument")
847
+ if timestep_list is not None:
848
+ deprecate(
849
+ "timestep_list",
850
+ "1.0.0",
851
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
852
+ )
853
+
854
+ if prev_timestep is not None:
855
+ deprecate(
856
+ "prev_timestep",
857
+ "1.0.0",
858
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
859
+ )
860
+
861
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
862
+ self.sigmas[self.step_index + 1],
863
+ self.sigmas[self.step_index],
864
+ self.sigmas[self.step_index - 1],
865
+ self.sigmas[self.step_index - 2],
866
+ )
867
+
868
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
869
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
870
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
871
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
872
+
873
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
874
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
875
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
876
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
877
+
878
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
879
+
880
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
881
+ r0, r1 = h_0 / h, h_1 / h
882
+ D0 = m0
883
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
884
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
885
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
886
+ if self.config.algorithm_type == "dpmsolver++":
887
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
888
+ x_t = (
889
+ (sigma_t / sigma_s0) * sample
890
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
891
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
892
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
893
+ )
894
+ elif self.config.algorithm_type == "dpmsolver":
895
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
896
+ x_t = (
897
+ (alpha_t / alpha_s0) * sample
898
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
899
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
900
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
901
+ )
902
+ return x_t
903
+
904
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
905
+ if schedule_timesteps is None:
906
+ schedule_timesteps = self.timesteps
907
+
908
+ index_candidates = (schedule_timesteps == timestep).nonzero()
909
+
910
+ if len(index_candidates) == 0:
911
+ step_index = len(self.timesteps) - 1
912
+ # The sigma index that is taken for the **very** first `step`
913
+ # is always the second index (or the last index if there is only 1)
914
+ # This way we can ensure we don't accidentally skip a sigma in
915
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
916
+ elif len(index_candidates) > 1:
917
+ step_index = index_candidates[1].item()
918
+ else:
919
+ step_index = index_candidates[0].item()
920
+
921
+ return step_index
922
+
923
+ def _init_step_index(self, timestep):
924
+ """
925
+ Initialize the step_index counter for the scheduler.
926
+ """
927
+
928
+ if self.begin_index is None:
929
+ if isinstance(timestep, torch.Tensor):
930
+ timestep = timestep.to(self.timesteps.device)
931
+ self._step_index = self.index_for_timestep(timestep)
932
+ else:
933
+ self._step_index = self._begin_index
934
+
935
+ def step(
936
+ self,
937
+ model_output: torch.Tensor,
938
+ timestep: int,
939
+ sample: torch.Tensor,
940
+ generator=None,
941
+ variance_noise: Optional[torch.Tensor] = None,
942
+ return_dict: bool = True,
943
+ ) -> Union[SchedulerOutput, Tuple]:
944
+ """
945
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
946
+ the multistep DPMSolver.
947
+
948
+ Args:
949
+ model_output (`torch.Tensor`):
950
+ The direct output from the learned diffusion model.
951
+ timestep (`int`):
952
+ The current discrete timestep in the diffusion chain.
953
+ sample (`torch.Tensor`):
954
+ A current instance of a sample created by the diffusion process.
955
+ generator (`torch.Generator`, *optional*):
956
+ A random number generator.
957
+ variance_noise (`torch.Tensor`):
958
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
959
+ itself. Useful for methods such as [`LEdits++`].
960
+ return_dict (`bool`):
961
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
962
+
963
+ Returns:
964
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
965
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
966
+ tuple is returned where the first element is the sample tensor.
967
+
968
+ """
969
+ if self.num_inference_steps is None:
970
+ raise ValueError(
971
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
972
+ )
973
+
974
+ if self.step_index is None:
975
+ self._init_step_index(timestep)
976
+
977
+ # Improve numerical stability for small number of steps
978
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
979
+ self.config.euler_at_final
980
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
981
+ or self.config.final_sigmas_type == "zero"
982
+ )
983
+ lower_order_second = (
984
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
985
+ )
986
+
987
+ model_output = self.convert_model_output(model_output, sample=sample)
988
+ for i in range(self.config.solver_order - 1):
989
+ self.model_outputs[i] = self.model_outputs[i + 1]
990
+ self.model_outputs[-1] = model_output
991
+
992
+ # Upcast to avoid precision issues when computing prev_sample
993
+ sample = sample.to(torch.float32)
994
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
995
+ noise = randn_tensor(
996
+ model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
997
+ )
998
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
999
+ noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
1000
+ else:
1001
+ noise = None
1002
+
1003
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
1004
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
1005
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
1006
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
1007
+ else:
1008
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1009
+
1010
+ if self.lower_order_nums < self.config.solver_order:
1011
+ self.lower_order_nums += 1
1012
+
1013
+ # Cast sample back to expected dtype
1014
+ prev_sample = prev_sample.to(model_output.dtype)
1015
+
1016
+ # upon completion increase step index by one
1017
+ self._step_index += 1
1018
+
1019
+ if not return_dict:
1020
+ return (prev_sample,)
1021
+
1022
+ return SchedulerOutput(prev_sample=prev_sample)
1023
+
1024
+ def add_noise(
1025
+ self,
1026
+ original_samples: torch.Tensor,
1027
+ noise: torch.Tensor,
1028
+ timesteps: torch.IntTensor,
1029
+ ) -> torch.Tensor:
1030
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
1031
+ # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1032
+ # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1033
+ alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1034
+ sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1035
+ timesteps = timesteps.to(original_samples.device)
1036
+ alpha_t = alpha_t[timesteps].flatten()
1037
+ while len(alpha_t.shape) < len(original_samples.shape):
1038
+ alpha_t = alpha_t.unsqueeze(-1)
1039
+
1040
+ sigma_t = sigma_t[timesteps].flatten()
1041
+ while len(sigma_t.shape) < len(original_samples.shape):
1042
+ sigma_t = sigma_t.unsqueeze(-1)
1043
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
1044
+ return noisy_samples
1045
+
1046
+ def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
1047
+ # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1048
+ # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1049
+ alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1050
+ sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1051
+
1052
+ timesteps = timesteps.to(original_samples.device)
1053
+ alpha_t = alpha_t[timesteps].flatten()
1054
+ while len(alpha_t.shape) < len(original_samples.shape):
1055
+ alpha_t = alpha_t.unsqueeze(-1)
1056
+
1057
+ sigma_t = sigma_t[timesteps].flatten()
1058
+ while len(sigma_t.shape) < len(original_samples.shape):
1059
+ sigma_t = sigma_t.unsqueeze(-1)
1060
+
1061
+ velocity = alpha_t * noise - sigma_t * original_samples
1062
+ return velocity
1063
+
1064
+ def __len__(self):
1065
+ return self.config.num_train_timesteps
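
For orientation, the scheduler above follows the standard multistep DPM-Solver interface: `set_timesteps` must be called before sampling (the check at the top of `step` enforces this), each call to `step` consumes one entry of `self.timesteps`, and the denoised latent for the previous timestep comes back as `SchedulerOutput.prev_sample`. A minimal sampling-loop sketch, assuming the class in this file is importable as `DPMSolverMultistepScheduler` from a hypothetical `schedule.dpm_solver` module, with a random stand-in for the real diffusion head:

import torch

# Hypothetical module/class names; adjust to however this repository exposes the scheduler.
from schedule.dpm_solver import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()    # construct with whatever config the repo uses
scheduler.set_timesteps(20)                  # required before step(), per the check in step()

latent = torch.randn(1, 64)                  # illustrative latent shape
model = lambda x, t: torch.randn_like(x)     # placeholder for the real diffusion head

for t in scheduler.timesteps:
    model_output = model(latent, t)                                # epsilon / sample / v, per prediction_type
    latent = scheduler.step(model_output, t, latent).prev_sample   # one multistep solver update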
schedule/timestep_sampler.py ADDED
@@ -0,0 +1,19 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ class UniformSampler:
6
+ def __init__(self, timesteps = 1000):
7
+ self.timesteps = timesteps
8
+ def sample(self, batch_size, device):
9
+ return torch.randint(0, self.timesteps, (batch_size,), device=device)
10
+
11
+ class LogitNormalSampler:
12
+ def __init__(self, timesteps = 1000, m = 0, s = 1):
13
+ self.timesteps = timesteps
14
+ timesteps = torch.linspace(0, 1, timesteps)
15
+ logit = torch.log(timesteps / (1 - timesteps))
16
+ self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
17
+ def sample(self, batch_size, device):
18
+ return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
19
+
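
Both samplers expose the same `sample(batch_size, device)` interface and return integer timestep indices in `[0, timesteps)`, so they can be swapped when drawing diffusion timesteps for a training batch. A small usage sketch (only the import path mirrors the file added above; the batch size and device handling are illustrative):

import torch

from schedule.timestep_sampler import UniformSampler, LogitNormalSampler

device = "cuda" if torch.cuda.is_available() else "cpu"

uniform = UniformSampler(timesteps=1000)
weighted = LogitNormalSampler(timesteps=1000, m=0, s=1)

t_uniform = uniform.sample(batch_size=8, device=device)    # shape (8,), uniform over [0, 1000)
t_weighted = weighted.sample(batch_size=8, device=device)  # indices drawn with logit-normal weights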
scripts/__init__.py ADDED
File without changes
scripts/convert_nnscaler_checkpoint_to_transformers.py ADDED
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ import argparse
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import re
9
+ import torch
10
+ from typing import Dict, List, Tuple
11
+
12
+ from vibevoice.modular.configuration_vibevoice import (
13
+ VibeVoiceConfig
14
+ )
15
+ from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
16
+ from transformers.utils import logging
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ def convert_vibevoice_nnscaler_checkpoint_to_hf(
21
+ checkpoint_path: str,
22
+ pytorch_dump_folder_path: str,
23
+ config_path: str = None,
24
+ ):
25
+ """
26
+ Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
27
+ Supports both regular checkpoints and tensor parallel checkpoints.
28
+ """
29
+
30
+ # Load regular checkpoint
31
+ logger.info(f"Loading regular checkpoint from {checkpoint_path}")
32
+ checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
33
+
34
+ # config = checkpoint['train_args']
35
+ init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
36
+ pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
37
+
38
+ init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
39
+ if init_config_path.exists():
40
+ logger.info(f"Loading initial config from {init_config_path}")
41
+ with open(init_config_path, 'r') as f:
42
+ init_config = json.load(f)
43
+ else:
44
+ raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
45
+
46
+ tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
47
+ logger.info(f"Tie word embeddings: {tie_word_embeddings}")
48
+
49
+ init_config['decoder_config']['use_cache'] = True
50
+ config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
51
+
52
+ # # Extract the model state dict
53
+ model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
54
+ if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
55
+ # If not tying weights, we need to add the lm_head weight separately
56
+ model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
57
+
58
+ # Override with provided config if available
59
+ if config_path:
60
+ logger.info(f"Loading config from {config_path}")
61
+ with open(config_path, 'r') as f:
62
+ config_dict = json.load(f)
63
+ config = VibeVoiceConfig.from_dict(config_dict)
64
+
65
+ # Set the default dtype to bfloat16 before creating the model
66
+ original_dtype = torch.get_default_dtype()
67
+ torch.set_default_dtype(torch.bfloat16)
68
+
69
+ # Create the HuggingFace model
70
+ logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
71
+ model = VibeVoiceForConditionalGeneration(config)
72
+
73
+ # Restore original dtype
74
+ torch.set_default_dtype(original_dtype)
75
+
76
+ # Load the state dict
77
+ logger.info("Loading weights into model")
78
+ missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
79
+
80
+ if missing_keys:
81
+ logger.warning(f"Missing keys: {missing_keys}")
82
+ if unexpected_keys:
83
+ logger.warning(f"Unexpected keys: {unexpected_keys}")
84
+
85
+ # Create output directory
86
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
87
+
88
+ # Save the model and config
89
+ logger.info(f"Saving model to {pytorch_dump_folder_path}")
90
+
91
+ # Save config
92
+ config.save_pretrained(pytorch_dump_folder_path)
93
+
94
+ # Save VibeVoiceProcessor configuration
95
+ logger.info("Saving VibeVoiceProcessor configuration")
96
+ processor_config = {
97
+ "processor_class": "VibeVoiceProcessor",
98
+ "speech_tok_compress_ratio": 3200,
99
+ "db_normalize": True,
100
+ # Audio processor configuration
101
+ "audio_processor": {
102
+ "feature_extractor_type": "VibeVoiceTokenizerProcessor",
103
+ "sampling_rate": 24000,
104
+ "normalize_audio": True,
105
+ "target_dB_FS": -25,
106
+ "eps": 1e-6,
107
+ },
108
+ "language_model_pretrained_name": pretrained_name,
109
+ }
110
+
111
+ processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
112
+ with open(processor_config_path, 'w') as f:
113
+ json.dump(processor_config, f, indent=2)
114
+ logger.info(f"Saved processor config to {processor_config_path}")
115
+
116
+ # Save model with sharding
117
+ # save_pretrained handles tied weights automatically
118
+ logger.info("Saving model weights with sharding...")
119
+ model.save_pretrained(
120
+ pytorch_dump_folder_path,
121
+ max_shard_size="2GB", # Set maximum size for each shard
122
+ safe_serialization=True # Ensure saving in .safetensors format
123
+ )
124
+ logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
125
+
126
+ logger.info("Conversion complete!")
127
+
128
+ # Verify the saved model can be loaded
129
+ logger.info("Verifying saved model...")
130
+ loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
131
+ logger.info("Model successfully loaded from saved checkpoint!")
132
+
133
+ def main():
134
+ parser = argparse.ArgumentParser()
135
+ parser.add_argument(
136
+ "--nnscaler_checkpoint_path",
137
+ type=str,
138
+ required=True,
139
+ help="Path to the nnscaler checkpoint (.pt file). For tensor parallel checkpoints, "
140
+ "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
141
+ "and the script will automatically detect and merge all parts.",
142
+ )
143
+ parser.add_argument(
144
+ "--pytorch_dump_folder_path",
145
+ type=str,
146
+ required=True,
147
+ help="Path to the output PyTorch model directory",
148
+ )
149
+ parser.add_argument(
150
+ "--config_path",
151
+ type=str,
152
+ default=None,
153
+ help="Optional path to a config JSON file to override extracted config",
154
+ )
155
+
156
+ args = parser.parse_args()
157
+
158
+ convert_vibevoice_nnscaler_checkpoint_to_hf(
159
+ args.nnscaler_checkpoint_path,
160
+ args.pytorch_dump_folder_path,
161
+ args.config_path,
162
+ )
163
+
164
+
165
+ if __name__ == "__main__":
166
+ main()
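
The converter is driven by the three arguments registered in `main()` (`--nnscaler_checkpoint_path`, `--pytorch_dump_folder_path`, and the optional `--config_path` override). The same conversion can also be invoked directly from Python; the sketch below mirrors the function signature above, with placeholder paths:

from scripts.convert_nnscaler_checkpoint_to_transformers import (
    convert_vibevoice_nnscaler_checkpoint_to_hf,
)

# Placeholder paths; point these at a real nnscaler checkpoint and an output directory.
convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path="/path/to/nnscaler_checkpoint.pt",
    pytorch_dump_folder_path="/path/to/vibevoice_hf",
    config_path=None,  # or a JSON config to override the one extracted from the checkpoint
)

A successful run also reloads the exported folder with `VibeVoiceForConditionalGeneration.from_pretrained`, so it doubles as a check that the saved config, processor config, and sharded safetensors weights load cleanly.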