Upload folder using huggingface_hub
- __init__.py +0 -0
- configs/qwen2.5_1.5b_64k.json +112 -0
- configs/qwen2.5_7b_32k.json +113 -0
- modular/__init__.py +0 -0
- modular/configuration_vibevoice.py +248 -0
- modular/modeling_vibevoice.py +488 -0
- modular/modeling_vibevoice_inference.py +715 -0
- modular/modular_vibevoice_diffusion_head.py +287 -0
- modular/modular_vibevoice_text_tokenizer.py +214 -0
- modular/modular_vibevoice_tokenizer.py +1195 -0
- modular/streamer.py +264 -0
- processor/__init__.py +0 -0
- processor/vibevoice_processor.py +677 -0
- processor/vibevoice_tokenizer_processor.py +483 -0
- schedule/__init__.py +0 -0
- schedule/dpm_solver.py +1065 -0
- schedule/timestep_sampler.py +19 -0
- scripts/__init__.py +0 -0
- scripts/convert_nnscaler_checkpoint_to_transformers.py +166 -0
__init__.py
ADDED
File without changes
configs/qwen2.5_1.5b_64k.json
ADDED
@@ -0,0 +1,112 @@
{
  "_attn_implementation_autoset": true,
  "acoustic_vae_dim": 64,
  "acoustic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "decoder_depths": null,
    "decoder_n_filters": 32,
    "decoder_ratios": [8, 5, 5, 4, 2, 2],
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0.5,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_acoustic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "gaussian",
    "vae_dim": 64,
    "weight_init_value": 0.01
  },
  "decoder_config": {
    "attention_dropout": 0.0,
    "hidden_act": "silu",
    "hidden_size": 1536,
    "initializer_range": 0.02,
    "intermediate_size": 8960,
    "max_position_embeddings": 65536,
    "max_window_layers": 28,
    "model_type": "qwen2",
    "num_attention_heads": 12,
    "num_hidden_layers": 28,
    "num_key_value_heads": 2,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "tie_word_embeddings": true,
    "torch_dtype": "bfloat16",
    "use_cache": true,
    "use_sliding_window": false,
    "vocab_size": 151936
  },
  "diffusion_head_config": {
    "ddpm_batch_mul": 4,
    "ddpm_beta_schedule": "cosine",
    "ddpm_num_inference_steps": 20,
    "ddpm_num_steps": 1000,
    "diffusion_type": "ddpm",
    "head_ffn_ratio": 3.0,
    "head_layers": 4,
    "hidden_size": 1536,
    "latent_size": 64,
    "model_type": "vibepod_diffusion_head",
    "prediction_type": "v_prediction",
    "rms_norm_eps": 1e-05,
    "speech_vae_dim": 64
  },
  "model_type": "vibepod",
  "semantic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_semantic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "none",
    "vae_dim": 128,
    "weight_init_value": 0.01
  },
  "semantic_vae_dim": 128,
  "torch_dtype": "bfloat16"
}
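Note: a minimal sketch (not part of the uploaded files) of how this 64k config can be inspected with the standard library; the relative path assumes the repo layout of this commit.

# Load the 1.5B / 64k config and read a few nested fields (stdlib only).
import json

with open("configs/qwen2.5_1.5b_64k.json") as f:
    cfg = json.load(f)

decoder = cfg["decoder_config"]
print(decoder["hidden_size"], decoder["num_hidden_layers"], decoder["max_position_embeddings"])
# -> 1536 28 65536
print(cfg["acoustic_vae_dim"], cfg["semantic_vae_dim"])
# -> 64 128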
configs/qwen2.5_7b_32k.json
ADDED
@@ -0,0 +1,113 @@
{
  "_attn_implementation_autoset": true,
  "acoustic_vae_dim": 64,
  "acoustic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "decoder_depths": null,
    "decoder_n_filters": 32,
    "decoder_ratios": [8, 5, 5, 4, 2, 2],
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0.5,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_acoustic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "gaussian",
    "vae_dim": 64,
    "weight_init_value": 0.01
  },
  "decoder_config": {
    "attention_dropout": 0.0,
    "hidden_act": "silu",
    "hidden_size": 3584,
    "initializer_range": 0.02,
    "intermediate_size": 18944,
    "max_position_embeddings": 32768,
    "max_window_layers": 28,
    "model_type": "qwen2",
    "num_attention_heads": 28,
    "num_hidden_layers": 28,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-06,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "tie_word_embeddings": false,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.40.1",
    "use_cache": true,
    "use_mrope": false,
    "use_sliding_window": false,
    "vocab_size": 152064
  },
  "diffusion_head_config": {
    "ddpm_batch_mul": 4,
    "ddpm_beta_schedule": "cosine",
    "ddpm_num_inference_steps": 20,
    "ddpm_num_steps": 1000,
    "diffusion_type": "ddpm",
    "head_ffn_ratio": 3.0,
    "head_layers": 4,
    "hidden_size": 3584,
    "latent_size": 64,
    "model_type": "vibepod_diffusion_head",
    "prediction_type": "v_prediction",
    "rms_norm_eps": 1e-05,
    "speech_vae_dim": 64
  },
  "model_type": "vibepod",
  "semantic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_semantic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "none",
    "vae_dim": 128,
    "weight_init_value": 0.01
  },
  "semantic_vae_dim": 128,
  "torch_dtype": "bfloat16"
}
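Note: the two configs share the tokenizer and diffusion-head settings and differ mainly in the Qwen2 decoder section. A minimal sketch (stdlib only, paths from this commit) that prints the differing decoder keys:

# Diff the decoder sections of the two shipped configs.
import json

with open("configs/qwen2.5_1.5b_64k.json") as f:
    small = json.load(f)["decoder_config"]
with open("configs/qwen2.5_7b_32k.json") as f:
    large = json.load(f)["decoder_config"]

for key in sorted(set(small) | set(large)):
    if small.get(key) != large.get(key):
        print(f"{key}: {small.get(key)} -> {large.get(key)}")
# Expected differences include hidden_size (1536 -> 3584), intermediate_size
# (8960 -> 18944), num_attention_heads (12 -> 28), num_key_value_heads (2 -> 4),
# max_position_embeddings (65536 -> 32768), vocab_size (151936 -> 152064),
# and tie_word_embeddings (true -> false).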
modular/__init__.py
ADDED
File without changes
modular/configuration_vibevoice.py
ADDED
@@ -0,0 +1,248 @@
""" VibeVoice_AcousticTokenizer model configuration"""

from typing import Dict, List, Optional, Tuple

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

logger = logging.get_logger(__name__)


class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
    model_type = "vibevoice_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = 'gaussian',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        # decoder specific
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,  # if None, same as encoder
        decoder_depths: Optional[str] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths

        # decoder specific parameters
        self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths


class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
    model_type = "vibevoice_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = 'none',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths


class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
    model_type = "vibevoice_diffusion_head"

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-5,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=20,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_batch_mul = ddpm_batch_mul

        super().__init__(**kwargs)


class VibeVoiceConfig(PretrainedConfig):
    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
    }
    # keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs
    ):

        # kwargs["_attn_implementation"] = "flash_attention_2"
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
        elif isinstance(acoustic_tokenizer_config, dict):
            acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
        elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
            # If an instance of the config class is provided
            self.acoustic_tokenizer_config = acoustic_tokenizer_config

        if semantic_tokenizer_config is None:
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
        elif isinstance(semantic_tokenizer_config, dict):
            semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
        elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
            # If an instance of the config class is provided
            self.semantic_tokenizer_config = semantic_tokenizer_config

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # If a dictionary is provided, instantiate the config class with it
            # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
            if decoder_config.get("model_type", '') == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
        elif isinstance(decoder_config, (Qwen2Config,)):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config

        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
        elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
            # If an instance of the config class is provided
            self.diffusion_head_config = diffusion_head_config

        # other parameters
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)

__all__ = [
    "VibeVoiceAcousticTokenizerConfig",
    "VibeVoiceSemanticTokenizerConfig",
    "VibeVoiceDiffusionHeadConfig",
    "VibeVoiceConfig"
]
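Note: a minimal sketch of building a VibeVoiceConfig from nested dicts, mirroring the dict branches of __init__ above. It assumes transformers with Qwen2 support is installed and that the `modular` package from this commit is importable; the exact field values are illustrative.

# Compose the top-level config from sub-config dicts; the decoder dict must
# carry model_type "qwen2", as required by the __init__ shown above.
from modular.configuration_vibevoice import VibeVoiceConfig

config = VibeVoiceConfig(
    acoustic_tokenizer_config={"vae_dim": 64, "std_dist_type": "gaussian"},
    semantic_tokenizer_config={"vae_dim": 128, "std_dist_type": "none"},
    decoder_config={"model_type": "qwen2", "hidden_size": 1536, "num_hidden_layers": 28},
    diffusion_head_config={"hidden_size": 1536, "prediction_type": "v_prediction"},
)
print(config.acoustic_vae_dim, config.semantic_vae_dim)  # 64 128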
modular/modeling_vibevoice.py
ADDED
@@ -0,0 +1,488 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from transformers.models.auto import AutoModel, AutoModelForCausalLM

from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging


from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler

from .configuration_vibevoice import VibeVoiceConfig


logger = logging.get_logger(__name__)

if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]

@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    diffusion_loss: Optional[torch.FloatTensor] = None
    speech_token_num: Optional[int] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
    """
    Output type for VibeVoice generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
    """
    sequences: torch.LongTensor = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None


class SpeechConnector(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        x = self.fc1(features)
        x = self.norm(x)
        x = self.fc2(x)
        return x


# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
    config_class = VibeVoiceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers - use 1D tensors for FSDP compatibility
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        assert False, 'should not arrive here'

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = VibeVoiceModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_decoder(self, decoder):
        self.model.language_model = decoder

    def get_decoder(self):
        return self.model.language_model

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        """
        if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
            # The standard PreTrainedModel method will handle the tying.
            # It typically does a simple parameter object assignment, which is
            # CORRECT to do BEFORE FSDP wraps the model.
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, 'weight'):
                output_embeddings.weight = input_embeddings.weight
            else:
                # maybe returned input_embeddings a tensor directly
                output_embeddings.weight = input_embeddings

            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
            print("✅ Tied input and output embeddings using standard assignment.")
        else:
            print("ℹ️ tie_word_embeddings is False, not tying weights.")

    # Also, ensure set_output_embeddings is safe, though your implementation looks okay.
    # The key is to avoid calling it after accelerator.prepare().
    def set_output_embeddings(self, new_embeddings):
        # Your current implementation using data.copy_ is good practice,
        # but the best way is to not call this after prepare().
        self.lm_head = new_embeddings

    def forward_speech_features(
        self,
        speech_tensors=None,
        speech_masks=None,
        speech_type="audio",
        return_unmask=False
    ):
        if speech_tensors is None:
            # Use config to get vae_dim instead of non-existent self.args
            vae_dim = self.config.acoustic_tokenizer_config.vae_dim
            audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
            connect_features = self.model.acoustic_connector(audio_features)
            return audio_features, connect_features
        else:
            with torch.no_grad():
                if speech_type == "audio":
                    with torch.no_grad():
                        frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
                    audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]

                elif speech_type == "vae":
                    # Use config to get vae_dim instead of non-existent self.args
                    vae_dim = self.config.acoustic_tokenizer_config.vae_dim
                    speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)

                    # gaussian sample from the speech_mode
                    batch_size = speech_mode.size(0)
                    value = self.model.acoustic_tokenizer.fix_std / 0.8
                    std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
                    std = std.view(-1, *[1] * (speech_mode.dim() - 1))
                    audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
                else:
                    raise NotImplementedError(f"Speech type {speech_type} not implemented")

                if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
                    scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
                    bias_factor = -audio_tokens[speech_masks].flatten().mean()

                    # Only use distributed operations if the process group is initialized
                    if dist.is_available() and dist.is_initialized():
                        dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                        dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                        world_size = dist.get_world_size()
                        self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
                        self.model.speech_bias_factor.copy_(bias_factor / world_size)
                        print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
                    else:
                        # Single process case
                        self.model.speech_scaling_factor.copy_(scaling_factor)
                        self.model.speech_bias_factor.copy_(bias_factor)
                        print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)

                audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor

            connect_features = self.model.acoustic_connector(audio_features)
            if return_unmask:
                return audio_features, connect_features
            return audio_features[speech_masks], connect_features[speech_masks]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # New arguments for speech processing and loss calculation
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speeches_loss_input: Optional[torch.FloatTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        acoustic_loss_mask: Optional[torch.BoolTensor] = None,
        ddpm_batch_mul: int = 1,
        **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.get_input_embeddings()(input_ids)

        semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
        if speeches_loss_input is not None:
            # only part audio need diffuse
            speech_all_features, speech_all_connect_features = self.forward_speech_features(
                speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
                speech_masks=speech_masks,
                speech_type=kwargs.get("speech_type", "audio"),
                return_unmask=True
            )
            if speech_tensors is not None:
                if semantic_speech_all_connect_features is not None:
                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
                else:
                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
            speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks]  # only part audio need diffuse
            speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
        else:
            speech_features, speech_connect_features = self.forward_speech_features(
                speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
                speech_masks=speech_masks,
                speech_type=kwargs.get("speech_type", "audio"),
            )
            if speech_tensors is not None:
                x[acoustic_input_mask] = speech_connect_features

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=x,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=False,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        # logits = logits.float()

        loss = None
        if labels is not None:
            # The custom CE loss with masking is calculated in the training script.
            # We leave the standard loss calculation here as None.
            pass

        # --- Diffusion Loss Calculation ---
        diffusion_loss = None
        # This block is executed only if we are in a context that involves speech.
        if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
            condition_features = hidden_states[acoustic_loss_mask]

            speech_len, latent_size = speech_features.shape

            noise = torch.randn(
                (speech_len * ddpm_batch_mul, latent_size),
                device=hidden_states.device,
                dtype=hidden_states.dtype
            )

            timesteps = torch.multinomial(
                torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
                speech_len * ddpm_batch_mul,
                replacement=True,
            ).to(hidden_states.device)

            speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
            condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)

            noisy_speech_features = self.model.noise_scheduler.add_noise(
                speech_features_repeated, noise, timesteps
            )

            model_output = self.model.prediction_head(
                noisy_speech_features,
                timesteps.type_as(x),
                condition_features_repeated
            )

            prediction_type = self.config.diffusion_head_config.prediction_type
            if prediction_type == "epsilon":
                target_for_loss = noise
            elif prediction_type == "v_prediction":
                target_for_loss = self.model.noise_scheduler.get_velocity(
                    speech_features_repeated, noise, timesteps
                )
            else:
                raise NotImplementedError(f"Prediction type {prediction_type} not implemented")

            diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
            if latent_size > 0 and ddpm_batch_mul > 0:
                diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
            else:
                diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)

        else:
            # Dummy loss for DDP to work when there are no speech samples in a batch,
            # but we are in a speech context.
            diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
        # --- End Diffusion Loss Calculation ---

        if not return_dict:
            output = (logits, speech_len) + outputs.to_tuple()[1:]
            return (loss, diffusion_loss) + output

        return VibeVoiceCausalLMOutputWithPast(
            loss=loss,
            diffusion_loss=diffusion_loss,
            speech_token_num=speech_len if speech_tensors is not None else 0,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)

__all__ = [
    "VibeVoiceModel",
    "VibeVoicePreTrainedModel",
    "VibeVoiceForConditionalGeneration",
    "VibeVoiceCausalLMOutputWithPast",
    "VibeVoiceGenerationOutput",
]
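Note: a standalone sketch of the latent normalization that forward_speech_features performs above. On the first speech batch, the scaling factor is 1/std and the bias is -mean of the masked acoustic latents, so (tokens + bias) * scale is roughly zero-mean and unit-variance before it enters the acoustic connector. Tensor shapes below are made up for illustration.

# Plain-torch illustration of the speech scaling/bias computation.
import torch

audio_tokens = torch.randn(2, 50, 64) * 3.0 + 1.5    # fake acoustic latents (B, T, D)
speech_masks = torch.ones(2, 50, dtype=torch.bool)   # all frames valid

scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std()
bias_factor = -audio_tokens[speech_masks].flatten().mean()

audio_features = (audio_tokens + bias_factor) * scaling_factor
print(audio_features.mean().item(), audio_features.std().item())  # ~0.0, ~1.0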
modular/modeling_vibevoice_inference.py
ADDED
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, List, Optional, Tuple, Union, Callable
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
8 |
+
|
9 |
+
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
|
10 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
11 |
+
from transformers import modeling_utils
|
12 |
+
from transformers.modeling_utils import PreTrainedModel
|
13 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
14 |
+
from transformers.utils import logging
|
15 |
+
|
16 |
+
|
17 |
+
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
18 |
+
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
|
19 |
+
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
20 |
+
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
21 |
+
|
22 |
+
from .configuration_vibevoice import VibeVoiceConfig
|
23 |
+
|
24 |
+
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
|
25 |
+
|
26 |
+
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
|
27 |
+
from .streamer import AudioStreamer, AsyncAudioStreamer
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
32 |
+
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
|
36 |
+
logits: Optional[torch.FloatTensor] = None
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class VibeVoiceGenerationOutput(ModelOutput):
|
40 |
+
"""
|
41 |
+
Output type for VibeVoice generation.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
45 |
+
The generated sequences.
|
46 |
+
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
47 |
+
List of generated speech waveforms or latents for each speech segment.
|
48 |
+
"""
|
49 |
+
sequences: torch.LongTensor = None
|
50 |
+
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
51 |
+
reach_max_step_sample: Optional[torch.BoolTensor] = None
|
52 |
+
|
53 |
+
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
|
54 |
+
"""Constrains token generation to only valid tokens during speech generation."""
|
55 |
+
|
56 |
+
def __init__(self, valid_token_ids: List[int], device: torch.device = None):
|
57 |
+
self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
|
58 |
+
|
59 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
60 |
+
# Create a mask for valid tokens
|
61 |
+
mask = torch.full_like(scores, float('-inf'))
|
62 |
+
mask[:, self.valid_token_ids] = 0
|
63 |
+
|
64 |
+
# Apply mask to scores
|
65 |
+
scores = scores + mask
|
66 |
+
return scores
|
67 |
+
|
68 |
+
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
|
69 |
+
_tied_weights_keys = ["lm_head.weight"]
|
70 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
71 |
+
|
72 |
+
def __init__(self, config):
|
73 |
+
super().__init__(config)
|
74 |
+
|
75 |
+
# Initialize the base model
|
76 |
+
self.model = VibeVoiceModel(config)
|
77 |
+
|
78 |
+
# LM head for text generation
|
79 |
+
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
|
80 |
+
|
81 |
+
# inference configuration
|
82 |
+
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
|
83 |
+
|
84 |
+
# Initialize weights and apply final processing
|
85 |
+
self.post_init()
|
86 |
+
|
87 |
+
@property
|
88 |
+
def noise_scheduler(self):
|
89 |
+
return self.model.noise_scheduler
|
90 |
+
|
91 |
+
@property
|
92 |
+
def prediction_head(self):
|
93 |
+
return self.model.prediction_head
|
94 |
+
|
95 |
+
@property
|
96 |
+
def speech_scaling_factor(self):
|
97 |
+
return self.model.speech_scaling_factor
|
98 |
+
|
99 |
+
@property
|
100 |
+
def speech_bias_factor(self):
|
101 |
+
return self.model.speech_bias_factor
|
102 |
+
|
103 |
+
@property
|
104 |
+
def acoustic_tokenizer(self):
|
105 |
+
return self.model.acoustic_tokenizer
|
106 |
+
|
107 |
+
@property
|
108 |
+
def semantic_tokenizer(self):
|
109 |
+
return self.model.semantic_tokenizer
|
110 |
+
|
111 |
+
@property
|
112 |
+
def acoustic_connector(self):
|
113 |
+
return self.model.acoustic_connector
|
114 |
+
|
115 |
+
@property
|
116 |
+
def semantic_connector(self):
|
117 |
+
return self.model.semantic_connector
|
118 |
+
|
119 |
+
def tie_weights(self):
|
120 |
+
"""
|
121 |
+
Tie the weights between the input embeddings and the output embeddings.
|
122 |
+
"""
|
123 |
+
# Tie lm_head.weight to language_model.embed_tokens.weight
|
124 |
+
if not getattr(self.config, 'tie_word_embeddings', False):
|
125 |
+
return
|
126 |
+
|
127 |
+
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
|
128 |
+
self.lm_head.weight = self.model.language_model.embed_tokens.weight
|
129 |
+
|
130 |
+
def get_input_embeddings(self):
|
131 |
+
return self.model.get_input_embeddings()
|
132 |
+
|
133 |
+
def set_input_embeddings(self, value):
|
134 |
+
self.model.set_input_embeddings(value)
|
135 |
+
|
136 |
+
def get_output_embeddings(self):
|
137 |
+
return self.lm_head
|
138 |
+
|
139 |
+
def set_output_embeddings(self, new_embeddings):
|
140 |
+
self.lm_head = new_embeddings
|
141 |
+
|
142 |
+
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
143 |
+
"""Set the speech tokenizers used for encoding and decoding speech."""
|
144 |
+
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
|
145 |
+
|
146 |
+
def set_ddpm_inference_steps(self, num_steps=None):
|
147 |
+
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
|
148 |
+
|
149 |
+
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
|
150 |
+
"""Process speech inputs through tokenizers and connectors."""
|
151 |
+
with torch.no_grad():
|
152 |
+
if speech_type == "audio":
|
153 |
+
# Encode audio to acoustic latents
|
154 |
+
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
155 |
+
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
156 |
+
|
157 |
+
# Apply scaling and bias
|
158 |
+
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
159 |
+
|
160 |
+
# Connect to language model space
|
161 |
+
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
162 |
+
|
163 |
+
return acoustic_features, acoustic_connected
|
164 |
+
elif speech_type == "pt":
|
165 |
+
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
|
166 |
+
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
167 |
+
|
168 |
+
# Apply scaling and bias
|
169 |
+
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
170 |
+
|
171 |
+
# Connect to language model space
|
172 |
+
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
173 |
+
|
174 |
+
return acoustic_features, acoustic_connected
|
175 |
+
else:
|
176 |
+
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
177 |
+
|
178 |
+
# @can_return_tuple
|
179 |
+
def forward(
|
180 |
+
self,
|
181 |
+
input_ids: torch.LongTensor = None,
|
182 |
+
attention_mask: Optional[torch.Tensor] = None,
|
183 |
+
position_ids: Optional[torch.LongTensor] = None,
|
184 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
185 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
186 |
+
labels: Optional[torch.LongTensor] = None,
|
187 |
+
use_cache: Optional[bool] = None,
|
188 |
+
output_attentions: Optional[bool] = None,
|
189 |
+
output_hidden_states: Optional[bool] = None,
|
190 |
+
return_dict: Optional[bool] = None,
|
191 |
+
cache_position: Optional[torch.LongTensor] = None,
|
192 |
+
speech_tensors: Optional[torch.FloatTensor] = None,
|
193 |
+
speech_masks: Optional[torch.BoolTensor] = None,
|
194 |
+
speech_input_mask: Optional[torch.BoolTensor] = None,
|
195 |
+
logits_to_keep: Union[int, slice] = 0,
|
196 |
+
**kwargs,
|
197 |
+
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
201 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
202 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
203 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
204 |
+
speech_tensors (`torch.FloatTensor`, *optional*):
|
205 |
+
Input speech waveforms for voice cloning or speech understanding.
|
206 |
+
speech_masks (`torch.BoolTensor`, *optional*):
|
207 |
+
Masks indicating valid speech frames.
|
208 |
+
speech_input_mask (`torch.BoolTensor`, *optional*):
|
209 |
+
Positions in the input sequence where speech embeddings should be inserted.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
`VibeVoiceCausalLMOutputWithPast` or tuple
|
213 |
+
"""
|
214 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
215 |
+
|
216 |
+
# Get embeddings
|
217 |
+
if inputs_embeds is None:
|
218 |
+
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
219 |
+
|
220 |
+
# Process speech inputs if provided
|
221 |
+
if speech_tensors is not None and speech_masks is not None:
|
222 |
+
acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
|
223 |
+
if speech_input_mask is not None:
|
224 |
+
inputs_embeds[speech_input_mask] = speech_embeds
|
225 |
+
|
226 |
+
outputs = self.model(
|
227 |
+
inputs_embeds=inputs_embeds,
|
228 |
+
attention_mask=attention_mask,
|
229 |
+
position_ids=position_ids,
|
230 |
+
past_key_values=past_key_values,
|
231 |
+
use_cache=use_cache,
|
232 |
+
output_attentions=output_attentions,
|
233 |
+
output_hidden_states=output_hidden_states,
|
234 |
+
return_dict=return_dict,
|
235 |
+
cache_position=cache_position,
|
236 |
+
**kwargs,
|
237 |
+
)
|
238 |
+
|
239 |
+
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
|
240 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
241 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
242 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
243 |
+
|
244 |
+
if labels is not None:
|
245 |
+
raise NotImplementedError("Loss computation is not implemented in this version.")
|
246 |
+
|
247 |
+
return VibeVoiceCausalLMOutputWithPast(
|
248 |
+
logits=logits,
|
249 |
+
past_key_values=outputs.past_key_values,
|
250 |
+
last_hidden_state=hidden_states,
|
251 |
+
attentions=outputs.attentions,
|
252 |
+
)
|
253 |
+
|
254 |
+
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
|
255 |
+
if generation_config is None:
|
256 |
+
generation_config = GenerationConfig(
|
257 |
+
bos_token_id=tokenizer.bos_token_id,
|
258 |
+
eos_token_id=tokenizer.eos_token_id,
|
259 |
+
pad_token_id=tokenizer.pad_token_id
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
generation_config = GenerationConfig(
|
263 |
+
**generation_config,
|
264 |
+
bos_token_id=tokenizer.bos_token_id,
|
265 |
+
eos_token_id=tokenizer.eos_token_id,
|
266 |
+
pad_token_id=tokenizer.pad_token_id
|
267 |
+
)
|
268 |
+
|
269 |
+
generation_config, model_kwargs = self._prepare_generation_config(
|
270 |
+
generation_config,
|
271 |
+
True,
|
272 |
+
speech_start_id=tokenizer.speech_start_id,
|
273 |
+
speech_end_id=tokenizer.speech_end_id,
|
274 |
+
speech_diffusion_id=tokenizer.speech_diffusion_id,
|
275 |
+
**kwargs
|
276 |
+
)
|
277 |
+
generation_config.speech_start_id = tokenizer.speech_start_id
|
278 |
+
generation_config.speech_end_id = tokenizer.speech_end_id
|
279 |
+
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
|
280 |
+
|
281 |
+
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
|
282 |
+
batch_size = inputs_tensor.shape[0]
|
283 |
+
device = self.device
|
284 |
+
|
285 |
+
self._prepare_special_tokens(generation_config, True, device=device)
|
286 |
+
generation_config.use_cache = True
|
287 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
288 |
+
input_ids = inputs_tensor.to(self.device)
|
289 |
+
|
290 |
+
input_ids_length = input_ids.shape[1]
|
291 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
292 |
+
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
293 |
+
generation_config = self._prepare_generated_length(
|
294 |
+
generation_config=generation_config,
|
295 |
+
has_default_max_length=has_default_max_length,
|
296 |
+
has_default_min_length=has_default_min_length,
|
297 |
+
model_input_name=model_input_name,
|
298 |
+
inputs_tensor=inputs_tensor,
|
299 |
+
input_ids_length=input_ids_length,
|
300 |
+
)
|
301 |
+
|
302 |
+
max_cache_length = generation_config.max_length - 1
|
303 |
+
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
304 |
+
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
|
305 |
+
for k, v in model_kwargs.items():
|
306 |
+
if isinstance(v, torch.Tensor):
|
307 |
+
model_kwargs[k] = v.to(device=device)
|
308 |
+
|
309 |
+
if return_processors:
|
310 |
+
logits_processor = self._get_logits_processor(
|
311 |
+
generation_config=generation_config,
|
312 |
+
input_ids_seq_length=input_ids_length,
|
313 |
+
encoder_input_ids=inputs_tensor,
|
314 |
+
prefix_allowed_tokens_fn=None,
|
315 |
+
logits_processor=LogitsProcessorList(),
|
316 |
+
device=inputs_tensor.device,
|
317 |
+
model_kwargs=model_kwargs,
|
318 |
+
)
|
319 |
+
|
320 |
+
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
|
321 |
+
|
322 |
+
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
|
323 |
+
else:
|
324 |
+
return generation_config, model_kwargs, input_ids
|
325 |
+
|
326 |
+
@torch.no_grad()
|
327 |
+
def generate(
|
328 |
+
self,
|
329 |
+
inputs: Optional[torch.Tensor] = None,
|
330 |
+
generation_config: Optional[GenerationConfig] = None,
|
331 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
332 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
333 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
334 |
+
synced_gpus: Optional[bool] = None,
|
335 |
+
assistant_model: Optional["PreTrainedModel"] = None,
|
336 |
+
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
|
337 |
+
negative_prompt_ids: Optional[torch.Tensor] = None,
|
338 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
339 |
+
speech_tensors: Optional[torch.FloatTensor] = None,
|
340 |
+
speech_masks: Optional[torch.BoolTensor] = None,
|
341 |
+
speech_input_mask: Optional[torch.BoolTensor] = None,
|
342 |
+
return_speech: bool = True,
|
343 |
+
cfg_scale: float = 1.0,
|
344 |
+
stop_check_fn: Optional[Callable[[], bool]] = None,
|
345 |
+
**kwargs,
|
346 |
+
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
|
347 |
+
"""
|
348 |
+
Generates sequences of token ids and optionally speech outputs.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
All standard generation arguments from GenerationMixin
|
352 |
+
negative_prompt_ids: Negative prompt for CFG in speech generation
|
353 |
+
negative_prompt_attention_mask: Attention mask for negative prompt
|
354 |
+
speech_tensors: Input speech for voice cloning
|
355 |
+
speech_masks: Masks for speech tensors
|
356 |
+
speech_input_mask: Positions to insert speech embeddings
|
357 |
+
return_speech: Whether to decode and return speech outputs
|
358 |
+
cfg_scale: CFG scale for speech generation
|
359 |
+
stop_check_fn: Optional callable that returns True if generation should stop
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
Generated token sequences and optionally speech outputs
|
363 |
+
"""
|
364 |
+
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
365 |
+
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
366 |
+
parsed_scripts = kwargs.pop("parsed_scripts", None)
|
367 |
+
all_speakers_list = kwargs.pop("all_speakers_list", None)
|
368 |
+
max_length_times = kwargs.pop("max_length_times", 2)
|
369 |
+
|
370 |
+
if kwargs.get('max_new_tokens', None) is None:
|
371 |
+
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
|
372 |
+
|
373 |
+
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
|
374 |
+
generation_config, inputs, tokenizer, return_processors=True, **kwargs
|
375 |
+
)
|
376 |
+
|
377 |
+
negative_kwargs = {
|
378 |
+
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
|
379 |
+
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
|
380 |
+
'max_new_tokens': kwargs.get('max_new_tokens', 100)
|
381 |
+
}
|
382 |
+
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
|
383 |
+
None, None, tokenizer, return_processors=False, **negative_kwargs
|
384 |
+
)
|
385 |
+
|
386 |
+
acoustic_cache = VibeVoiceTokenizerStreamingCache()
|
387 |
+
semantic_cache = VibeVoiceTokenizerStreamingCache()
|
388 |
+
|
389 |
+
batch_size = input_ids.shape[0]
|
390 |
+
device = input_ids.device
|
391 |
+
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
392 |
+
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
|
393 |
+
is_prefill = True
|
394 |
+
inputs_embeds = None
|
395 |
+
verbose = kwargs.get("verbose", False)
|
396 |
+
|
397 |
+
# Initialize audio chunks storage for each sample
|
398 |
+
audio_chunks = [[] for _ in range(batch_size)]
|
399 |
+
|
400 |
+
initial_length = input_ids.shape[-1]
|
401 |
+
initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
|
402 |
+
|
403 |
+
# Define all valid tokens that can be generated
|
404 |
+
valid_tokens = [
|
405 |
+
generation_config.speech_start_id,
|
406 |
+
generation_config.speech_end_id,
|
407 |
+
generation_config.speech_diffusion_id,
|
408 |
+
generation_config.eos_token_id
|
409 |
+
]
|
410 |
+
# Add bos_token_id if it exists
|
411 |
+
if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
|
412 |
+
valid_tokens.append(generation_config.bos_token_id)
|
413 |
+
|
414 |
+
# Add custom processor to constrain token generation
|
415 |
+
token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
|
416 |
+
if logits_processor is None:
|
417 |
+
logits_processor = LogitsProcessorList()
|
418 |
+
logits_processor.append(token_constraint_processor)
|
419 |
+
|
420 |
+
max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
|
421 |
+
max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
|
422 |
+
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
423 |
+
|
424 |
+
# Create progress iterator if verbose
|
425 |
+
if kwargs.get("show_progress_bar", True):
|
426 |
+
progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
|
427 |
+
else:
|
428 |
+
progress_bar = range(max_steps)
|
429 |
+
|
430 |
+
for step in progress_bar:
|
431 |
+
# Check for external stop signal
|
432 |
+
if stop_check_fn is not None and stop_check_fn():
|
433 |
+
if verbose:
|
434 |
+
print(f"Generation stopped externally at step {step + 1}")
|
435 |
+
# End the audio streamer if it exists
|
436 |
+
if audio_streamer is not None:
|
437 |
+
audio_streamer.end()
|
438 |
+
break
|
439 |
+
|
440 |
+
# Check if audio_streamer has been ended (stopped externally)
|
441 |
+
if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
|
442 |
+
if any(audio_streamer.finished_flags):
|
443 |
+
if verbose:
|
444 |
+
print(f"Audio generation stopped externally at step {step + 1}")
|
445 |
+
break
|
446 |
+
|
447 |
+
if finished_tags.all():
|
448 |
+
if hasattr(progress_bar, 'set_description'):
|
449 |
+
progress_bar.set_description("Generation complete")
|
450 |
+
break
|
451 |
+
|
452 |
+
if input_ids.shape[-1] >= generation_config.max_length:
|
453 |
+
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
|
454 |
+
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
|
455 |
+
if reached_samples.numel() > 0:
|
456 |
+
reach_max_step_sample[reached_samples] = True
|
457 |
+
break
|
458 |
+
|
459 |
+
# Update progress bar description with active samples
|
460 |
+
if hasattr(progress_bar, 'set_description'):
|
461 |
+
active_samples = (~finished_tags).sum().item()
|
462 |
+
progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
|
463 |
+
|
464 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
465 |
+
if is_prefill:
|
466 |
+
# we process the speech inputs only during the first generation step
|
467 |
+
prefill_inputs = {
|
468 |
+
"speech_tensors": speech_tensors.to(device=device),
|
469 |
+
"speech_masks": speech_masks.to(device),
|
470 |
+
"speech_input_mask": speech_input_mask.to(device),
|
471 |
+
}
|
472 |
+
is_prefill = False
|
473 |
+
else:
|
474 |
+
_ = model_inputs.pop('inputs_embeds', None)
|
475 |
+
prefill_inputs = {'inputs_embeds': inputs_embeds}
|
476 |
+
|
477 |
+
# Forward pass through the model
|
478 |
+
outputs = self(
|
479 |
+
**model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
|
480 |
+
)
|
481 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
482 |
+
outputs, model_kwargs, is_encoder_decoder=False,
|
483 |
+
)
|
484 |
+
|
485 |
+
# Get logits and apply logits processor
|
486 |
+
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
487 |
+
# next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
|
488 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
489 |
+
|
490 |
+
# token selection
|
491 |
+
if generation_config.do_sample:
|
492 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
493 |
+
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
494 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
495 |
+
else:
|
496 |
+
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
497 |
+
|
498 |
+
next_tokens[finished_tags] = generation_config.eos_token_id
|
499 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
500 |
+
|
501 |
+
if not kwargs.get('refresh_negative', True):
|
502 |
+
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
503 |
+
# Forward negative pass through the model
|
504 |
+
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
505 |
+
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
506 |
+
negative_model_inputs['input_ids'] = None
|
507 |
+
|
508 |
+
negative_outputs = self(
|
509 |
+
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
510 |
+
)
|
511 |
+
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
512 |
+
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
513 |
+
)
|
514 |
+
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
515 |
+
|
516 |
+
# reached end of generation
|
517 |
+
if (next_tokens == generation_config.eos_token_id).any():
|
518 |
+
eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
|
519 |
+
# Only print for samples that are newly finished (not already marked as finished)
|
520 |
+
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
|
521 |
+
if new_eos_indices.numel() > 0:
|
522 |
+
finished_tags[new_eos_indices] = True
|
523 |
+
if verbose:
|
524 |
+
print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
|
525 |
+
if audio_streamer is not None:
|
526 |
+
audio_streamer.end(new_eos_indices)
|
527 |
+
|
528 |
+
# Check if any sample reached its maximum generation length
|
529 |
+
max_length_reached = step >= max_step_per_sample
|
530 |
+
new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
|
531 |
+
if new_max_length_indices.numel() > 0:
|
532 |
+
finished_tags[new_max_length_indices] = True
|
533 |
+
reach_max_step_sample[new_max_length_indices] = True
|
534 |
+
if verbose:
|
535 |
+
print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
|
536 |
+
if audio_streamer is not None:
|
537 |
+
audio_streamer.end(new_max_length_indices)
|
538 |
+
|
539 |
+
# speech_end
|
540 |
+
diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
|
541 |
+
if diffusion_end_indices.numel() > 0:
|
542 |
+
# Clear tokenizer caches for samples that reached speech end
|
543 |
+
acoustic_cache.set_to_zero(diffusion_end_indices)
|
544 |
+
semantic_cache.set_to_zero(diffusion_end_indices)
|
545 |
+
|
546 |
+
# speech_begin
|
547 |
+
diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
|
548 |
+
if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
|
549 |
+
# update attention mask
|
550 |
+
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
|
551 |
+
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
|
552 |
+
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
|
553 |
+
# update past key values
|
554 |
+
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
|
555 |
+
negative_model_kwargs['past_key_values'].value_cache)):
|
556 |
+
# Process each non-diffusion sample
|
557 |
+
for sample_idx in diffusion_start_indices.tolist():
|
558 |
+
# Shift cache for this sample
|
559 |
+
k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
|
560 |
+
v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
|
561 |
+
# update negative_input_ids
|
562 |
+
for sample_idx in diffusion_start_indices.tolist():
|
563 |
+
negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
|
564 |
+
|
565 |
+
# Prepare inputs_embeds for next iteration
|
566 |
+
# Initialize with default embeddings for all tokens
|
567 |
+
next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
|
568 |
+
|
569 |
+
# forward diffusion
|
570 |
+
# Diffusion indices are those that are not finished and not special tokens
|
571 |
+
diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
|
572 |
+
|
573 |
+
if diffusion_indices.numel() > 0:
|
574 |
+
if kwargs.get('refresh_negative', True):
|
575 |
+
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
576 |
+
# Forward negative pass through the model
|
577 |
+
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
578 |
+
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
579 |
+
negative_model_inputs['input_ids'] = None
|
580 |
+
|
581 |
+
negative_outputs = self(
|
582 |
+
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
583 |
+
)
|
584 |
+
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
585 |
+
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
586 |
+
)
|
587 |
+
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
588 |
+
# Correct the negative (CFG) branch for non-diffusion samples.
|
589 |
+
# The negative forward pass above is run for every sample, even those that
|
590 |
+
# are not in diffusion mode, so that the cache shapes stay consistent;
|
591 |
+
# the stale KV-cache entries of those non-diffusion samples are rolled back below.
|
592 |
+
non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
|
593 |
+
if non_diffusion_mask.any():
|
594 |
+
non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
|
595 |
+
start_indices = correct_cnt[non_diffusion_indices]
|
596 |
+
|
597 |
+
# 1. Update attention_mask - need to handle each sample separately
|
598 |
+
seq_len = negative_model_kwargs['attention_mask'].shape[1]
|
599 |
+
for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
|
600 |
+
# Shift the attention mask for this sample
|
601 |
+
if start_idx + 1 < seq_len - 1:
|
602 |
+
negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
|
603 |
+
negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
|
604 |
+
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
|
605 |
+
|
606 |
+
# 2. Update past_key_values
|
607 |
+
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
|
608 |
+
negative_model_kwargs['past_key_values'].value_cache)):
|
609 |
+
# Process each non-diffusion sample
|
610 |
+
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
611 |
+
if start_idx + 1 < k_cache.shape[2] - 1:
|
612 |
+
# Shift cache for this sample
|
613 |
+
k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
|
614 |
+
v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
|
615 |
+
|
616 |
+
# 3. Update negative_input_ids
|
617 |
+
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
618 |
+
if start_idx + 1 < negative_input_ids.shape[1] - 1:
|
619 |
+
negative_input_ids[sample_idx, start_idx+1:] = \
|
620 |
+
negative_input_ids[sample_idx, start_idx:-1].clone()
|
621 |
+
|
622 |
+
correct_cnt[non_diffusion_indices] += 1
|
623 |
+
|
624 |
+
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
|
625 |
+
negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
|
626 |
+
|
627 |
+
speech_latent = self.sample_speech_tokens(
|
628 |
+
positive_condition,
|
629 |
+
negative_condition,
|
630 |
+
cfg_scale=cfg_scale,
|
631 |
+
).unsqueeze(1)
|
632 |
+
|
633 |
+
# Decode acoustic latent to audio using acoustic streaming cache
|
634 |
+
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
|
635 |
+
audio_chunk = self.model.acoustic_tokenizer.decode(
|
636 |
+
scaled_latent.to(self.model.acoustic_tokenizer.device),
|
637 |
+
cache=acoustic_cache, # Use acoustic-specific cache
|
638 |
+
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
|
639 |
+
use_cache=True,
|
640 |
+
debug=False
|
641 |
+
)
|
642 |
+
|
643 |
+
# Store audio chunks for each sample
|
644 |
+
for i, sample_idx in enumerate(diffusion_indices):
|
645 |
+
idx = sample_idx.item()
|
646 |
+
# Only append audio chunk if the sample is not finished
|
647 |
+
if not finished_tags[idx]:
|
648 |
+
audio_chunks[idx].append(audio_chunk[i])
|
649 |
+
|
650 |
+
# Add streaming support here
|
651 |
+
if audio_streamer is not None:
|
652 |
+
# Stream the audio chunks immediately
|
653 |
+
audio_streamer.put(audio_chunk, diffusion_indices)
|
654 |
+
|
655 |
+
# Encode audio to semantic features using semantic streaming cache
|
656 |
+
semantic_features = self.model.semantic_tokenizer.encode(
|
657 |
+
audio_chunk,
|
658 |
+
cache=semantic_cache, # Use semantic-specific cache
|
659 |
+
sample_indices=diffusion_indices,
|
660 |
+
use_cache=True,
|
661 |
+
debug=False
|
662 |
+
).mean # semantic tokenizer has no VAE.
|
663 |
+
|
664 |
+
# Combine acoustic and semantic features for next input
|
665 |
+
acoustic_embed = self.model.acoustic_connector(speech_latent)
|
666 |
+
semantic_embed = self.model.semantic_connector(semantic_features)
|
667 |
+
diffusion_embeds = acoustic_embed + semantic_embed
|
668 |
+
|
669 |
+
# Update embeddings for diffusion indices
|
670 |
+
next_inputs_embeds[diffusion_indices] = diffusion_embeds
|
671 |
+
|
672 |
+
# Set inputs_embeds for next iteration
|
673 |
+
inputs_embeds = next_inputs_embeds
|
674 |
+
|
675 |
+
if audio_streamer is not None:
|
676 |
+
audio_streamer.end()
|
677 |
+
|
678 |
+
# Concatenate audio chunks for each sample
|
679 |
+
final_audio_outputs = []
|
680 |
+
for sample_chunks in audio_chunks:
|
681 |
+
if sample_chunks:
|
682 |
+
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
|
683 |
+
concatenated_audio = torch.cat(sample_chunks, dim=-1)
|
684 |
+
final_audio_outputs.append(concatenated_audio)
|
685 |
+
else:
|
686 |
+
# If no audio was generated for this sample, append None
|
687 |
+
final_audio_outputs.append(None)
|
688 |
+
|
689 |
+
return VibeVoiceGenerationOutput(
|
690 |
+
sequences=input_ids,
|
691 |
+
speech_outputs=final_audio_outputs if return_speech else None,
|
692 |
+
reach_max_step_sample=reach_max_step_sample,
|
693 |
+
)
|
694 |
+
|
695 |
+
@torch.no_grad()
|
696 |
+
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
|
697 |
+
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
|
698 |
+
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
|
699 |
+
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
|
700 |
+
for t in self.model.noise_scheduler.timesteps:
|
701 |
+
half = speech[: len(speech) // 2]
|
702 |
+
combined = torch.cat([half, half], dim=0)
|
703 |
+
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
|
704 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
705 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
706 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
707 |
+
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
|
708 |
+
return speech[: len(speech) // 2]
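# Illustrative note: the guidance step in the loop above is standard
# classifier-free guidance,
#     eps_guided = eps_uncond + cfg_scale * (eps_cond - eps_uncond),
# so cfg_scale = 1.0 reduces to the purely conditional prediction, while larger
# values push the sampled latent further toward the positive (text-conditioned) branch.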
|
709 |
+
|
710 |
+
|
711 |
+
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
|
712 |
+
|
713 |
+
__all__ = [
|
714 |
+
"VibeVoiceForConditionalGenerationInference",
|
715 |
+
]
|
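A minimal loading sketch for the class registered above. The checkpoint path is a placeholder, and it assumes `VibeVoiceConfig` is also registered with `AutoConfig` elsewhere in this package; treat it as an illustration rather than the official entry point.

import torch
from transformers import AutoModelForCausalLM

# Placeholder directory; it resolves to VibeVoiceForConditionalGenerationInference
# through the AutoModelForCausalLM.register(...) call above.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/vibevoice-checkpoint",
    torch_dtype=torch.bfloat16,
).eval()

# Fewer DDPM steps trade some audio quality for speed; passing None keeps the
# default from config.diffusion_head_config.ddpm_num_inference_steps.
model.set_ddpm_inference_steps(num_steps=10)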
modular/modular_vibevoice_diffusion_head.py
ADDED
@@ -0,0 +1,287 @@
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from transformers.models.auto import AutoModel
|
9 |
+
from transformers.modeling_utils import PreTrainedModel
|
10 |
+
# from transformers.modeling_layers import GradientCheckpointingLayer
|
11 |
+
from transformers.activations import ACT2FN
|
12 |
+
from transformers.utils import logging
|
13 |
+
|
14 |
+
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.get_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class RMSNorm(nn.Module):
|
21 |
+
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
|
22 |
+
super().__init__()
|
23 |
+
self.dim = dim
|
24 |
+
self.eps = eps
|
25 |
+
self.elementwise_affine = elementwise_affine
|
26 |
+
if self.elementwise_affine:
|
27 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
28 |
+
else:
|
29 |
+
self.register_parameter('weight', None)
|
30 |
+
|
31 |
+
def _norm(self, x):
|
32 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
output = self._norm(x.float()).type_as(x)
|
36 |
+
if self.weight is not None:
|
37 |
+
output = output * self.weight
|
38 |
+
return output
|
39 |
+
|
40 |
+
def extra_repr(self) -> str:
|
41 |
+
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
42 |
+
|
43 |
+
def modulate(x, shift, scale):
|
44 |
+
"""Apply modulation to input tensor."""
|
45 |
+
return x * (1 + scale) + shift
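# Illustrative note: modulate() is the FiLM-style transform used for the adaLN
# conditioning below. The shift/scale (and gate) values are predicted from the
# condition vector, and because those projections are zero-initialized in
# initialize_weights, each HeadLayer starts out as an identity mapping and the
# FinalLayer starts out predicting zeros.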
|
46 |
+
|
47 |
+
|
48 |
+
class TimestepEmbedder(nn.Module):
|
49 |
+
"""
|
50 |
+
Embeds scalar timesteps into vector representations.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
hidden_size (`int`): Size of the output embedding
|
54 |
+
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
|
55 |
+
"""
|
56 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
57 |
+
super().__init__()
|
58 |
+
self.mlp = nn.Sequential(
|
59 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
|
60 |
+
# nn.SiLU(),
|
61 |
+
ACT2FN['silu'],
|
62 |
+
nn.Linear(hidden_size, hidden_size, bias=False),
|
63 |
+
)
|
64 |
+
self.frequency_embedding_size = frequency_embedding_size
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def timestep_embedding(t, dim, max_period=10000):
|
68 |
+
"""
|
69 |
+
Create sinusoidal timestep embeddings.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
|
73 |
+
These may be fractional.
|
74 |
+
dim (`int`): The dimension of the output.
|
75 |
+
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
|
79 |
+
"""
|
80 |
+
half = dim // 2
|
81 |
+
freqs = torch.exp(
|
82 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
83 |
+
).to(t.device)
|
84 |
+
args = t[:, None].float() * freqs[None]
|
85 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
86 |
+
if dim % 2:
|
87 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
88 |
+
return embedding.to(t.dtype)
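# Written out (illustrative), with half = dim // 2 and i in [0, half):
#     freq_i                 = exp(-ln(max_period) * i / half)
#     embedding[:, i]        = cos(t * freq_i)
#     embedding[:, half + i] = sin(t * freq_i)
# and a single zero column is appended when dim is odd.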
|
89 |
+
|
90 |
+
def forward(self, t):
|
91 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
92 |
+
t_emb = self.mlp(t_freq)
|
93 |
+
return t_emb
|
94 |
+
|
95 |
+
|
96 |
+
class FeedForwardNetwork(nn.Module):
|
97 |
+
"""
|
98 |
+
Standard feed-forward network with SwiGLU activation.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
embed_dim (`int`): Input dimension
|
102 |
+
ffn_dim (`int`): Hidden dimension
|
103 |
+
"""
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
embed_dim,
|
107 |
+
ffn_dim,
|
108 |
+
):
|
109 |
+
super().__init__()
|
110 |
+
self.embed_dim = embed_dim
|
111 |
+
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
112 |
+
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
113 |
+
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
|
114 |
+
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
gate = self.gate_proj(x)
|
118 |
+
up = self.up_proj(x)
|
119 |
+
|
120 |
+
# SwiGLU activation
|
121 |
+
# gate = F.silu(gate)
|
122 |
+
gate = self.act_fn(gate)
|
123 |
+
return self.down_proj(gate * up)
|
124 |
+
|
125 |
+
|
126 |
+
class HeadLayer(nn.Module):
|
127 |
+
"""
|
128 |
+
A layer in the diffusion head.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
embed_dim (`int`): Input dimension
|
132 |
+
ffn_dim (`int`): Hidden dimension
|
133 |
+
cond_dim (`int`): Condition embedding dimension
|
134 |
+
norm_eps (`float`, optional): Epsilon for normalization
|
135 |
+
"""
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
embed_dim,
|
139 |
+
ffn_dim,
|
140 |
+
cond_dim,
|
141 |
+
norm_eps=1e-5,
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
self.embed_dim = embed_dim
|
145 |
+
self.cond_dim = cond_dim
|
146 |
+
self.ffn_dim = ffn_dim
|
147 |
+
self.ffn = FeedForwardNetwork(
|
148 |
+
self.embed_dim,
|
149 |
+
self.ffn_dim,
|
150 |
+
)
|
151 |
+
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
|
152 |
+
self.adaLN_modulation = nn.Sequential(
|
153 |
+
# nn.SiLU(),
|
154 |
+
ACT2FN['silu'],
|
155 |
+
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
|
156 |
+
)
|
157 |
+
|
158 |
+
def forward(self, x, c):
|
159 |
+
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
|
160 |
+
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
|
161 |
+
return x
|
162 |
+
|
163 |
+
|
164 |
+
class FinalLayer(nn.Module):
|
165 |
+
"""
|
166 |
+
Final layer in the diffusion head.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
hidden_size (`int`): Input dimension
|
170 |
+
output_size (`int`): Output dimension
|
171 |
+
cond_size (`int`): Condition embedding dimension
|
172 |
+
norm_eps (`float`, optional): Epsilon for normalization
|
173 |
+
"""
|
174 |
+
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
|
175 |
+
super().__init__()
|
176 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
|
177 |
+
self.linear = nn.Linear(hidden_size, output_size, bias=False)
|
178 |
+
self.adaLN_modulation = nn.Sequential(
|
179 |
+
# nn.SiLU(),
|
180 |
+
ACT2FN['silu'],
|
181 |
+
nn.Linear(cond_size, 2 * hidden_size, bias=False)
|
182 |
+
)
|
183 |
+
|
184 |
+
def forward(self, x, c):
|
185 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
186 |
+
x = modulate(self.norm_final(x), shift, scale)
|
187 |
+
x = self.linear(x)
|
188 |
+
return x
|
189 |
+
|
190 |
+
|
191 |
+
class VibeVoiceDiffusionHead(PreTrainedModel):
|
192 |
+
"""
|
193 |
+
Diffusion head model for vibevoice.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
|
197 |
+
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
|
198 |
+
"""
|
199 |
+
config_class = VibeVoiceDiffusionHeadConfig
|
200 |
+
supports_gradient_checkpointing = True
|
201 |
+
_supports_flash_attn_2 = True
|
202 |
+
_supports_sdpa = True
|
203 |
+
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
config,
|
207 |
+
):
|
208 |
+
super().__init__(config)
|
209 |
+
self.config = config
|
210 |
+
self.cond_dim = config.hidden_size
|
211 |
+
latent_size = config.latent_size
|
212 |
+
|
213 |
+
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
|
214 |
+
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
|
215 |
+
self.t_embedder = TimestepEmbedder(self.cond_dim)
|
216 |
+
|
217 |
+
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
|
218 |
+
|
219 |
+
# Create the intermediate layers
|
220 |
+
self.layers = nn.ModuleList([
|
221 |
+
HeadLayer(
|
222 |
+
embed_dim=config.hidden_size,
|
223 |
+
ffn_dim=ffn_dim,
|
224 |
+
cond_dim=self.cond_dim,
|
225 |
+
norm_eps=config.rms_norm_eps
|
226 |
+
)
|
227 |
+
for _ in range(config.head_layers)
|
228 |
+
])
|
229 |
+
|
230 |
+
# Final layer for output
|
231 |
+
self.final_layer = FinalLayer(
|
232 |
+
hidden_size=config.hidden_size,
|
233 |
+
output_size=latent_size,
|
234 |
+
cond_size=self.cond_dim,
|
235 |
+
norm_eps=config.rms_norm_eps
|
236 |
+
)
|
237 |
+
|
238 |
+
self.initialize_weights()
|
239 |
+
|
240 |
+
def initialize_weights(self):
|
241 |
+
"""Initialize the weights of the model."""
|
242 |
+
# Initialize timestep embedder
|
243 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
244 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
245 |
+
|
246 |
+
# Zero-out adaLN modulation layers
|
247 |
+
for layer in self.layers:
|
248 |
+
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
|
249 |
+
|
250 |
+
# Zero-out output layers
|
251 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
252 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
253 |
+
|
254 |
+
def forward(
|
255 |
+
self,
|
256 |
+
noisy_images,
|
257 |
+
timesteps,
|
258 |
+
condition,
|
259 |
+
):
|
260 |
+
"""
|
261 |
+
Forward pass of the prediction head.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
|
265 |
+
timesteps (`torch.Tensor`): Timesteps for diffusion
|
266 |
+
condition (`torch.Tensor`): Conditioning information
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
`torch.Tensor`: The predicted noise/velocity
|
270 |
+
"""
|
271 |
+
x = self.noisy_images_proj(noisy_images)
|
272 |
+
t = self.t_embedder(timesteps)
|
273 |
+
condition = self.cond_proj(condition)
|
274 |
+
c = condition + t
|
275 |
+
|
276 |
+
for layer in self.layers:
|
277 |
+
x = layer(x, c)
|
278 |
+
|
279 |
+
x = self.final_layer(x, c)
|
280 |
+
return x
|
281 |
+
|
282 |
+
|
283 |
+
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
|
284 |
+
|
285 |
+
__all__ = [
|
286 |
+
"VibeVoiceDiffusionHead",
|
287 |
+
]
|
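A small shape-check sketch for the diffusion head above. The import paths and config values are assumptions chosen for illustration, not taken from a released checkpoint.

import torch

from modular.configuration_vibevoice import VibeVoiceDiffusionHeadConfig  # assumed import path
from modular.modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead

# Small illustrative config; every field the head reads is passed explicitly.
config = VibeVoiceDiffusionHeadConfig(
    hidden_size=256,
    latent_size=64,
    head_layers=2,
    head_ffn_ratio=3.0,
    rms_norm_eps=1e-5,
)
head = VibeVoiceDiffusionHead(config).eval()

batch = 4
noisy_latents = torch.randn(batch, config.latent_size)
timesteps = torch.randint(0, 1000, (batch,)).float()  # float timesteps, as in sample_speech_tokens
condition = torch.randn(batch, config.hidden_size)     # e.g. the LM's last hidden state

with torch.no_grad():
    eps = head(noisy_latents, timesteps, condition)    # predicted noise
print(eps.shape)  # torch.Size([4, 64])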
modular/modular_vibevoice_text_tokenizer.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
"""Tokenization classes for vibevoice."""
|
2 |
+
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
from transformers.utils import logging
|
6 |
+
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
7 |
+
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
|
13 |
+
"""
|
14 |
+
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
vocab_file (`str`):
|
18 |
+
Path to the vocabulary file.
|
19 |
+
merges_file (`str`):
|
20 |
+
Path to the merges file.
|
21 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
22 |
+
Paradigm to follow when decoding bytes to UTF-8.
|
23 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
24 |
+
The unknown token.
|
25 |
+
bos_token (`str`, *optional*):
|
26 |
+
The beginning of sequence token. Not used for vibevoice.
|
27 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
28 |
+
The end of sequence token.
|
29 |
+
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
30 |
+
The token used for padding.
|
31 |
+
add_special_tokens (`bool`, *optional*, defaults to `True`):
|
32 |
+
Whether or not to add special tokens when encoding.
|
33 |
+
"""
|
34 |
+
|
35 |
+
model_input_names = ["input_ids", "attention_mask"]
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
vocab_file,
|
40 |
+
merges_file,
|
41 |
+
errors="replace",
|
42 |
+
unk_token="<|endoftext|>",
|
43 |
+
bos_token=None,
|
44 |
+
eos_token="<|endoftext|>",
|
45 |
+
pad_token="<|endoftext|>",
|
46 |
+
add_prefix_space=False,
|
47 |
+
add_special_tokens=True,
|
48 |
+
**kwargs,
|
49 |
+
):
|
50 |
+
super().__init__(
|
51 |
+
vocab_file=vocab_file,
|
52 |
+
merges_file=merges_file,
|
53 |
+
errors=errors,
|
54 |
+
unk_token=unk_token,
|
55 |
+
bos_token=bos_token,
|
56 |
+
eos_token=eos_token,
|
57 |
+
pad_token=pad_token,
|
58 |
+
add_prefix_space=add_prefix_space,
|
59 |
+
add_special_tokens=add_special_tokens,
|
60 |
+
**kwargs,
|
61 |
+
)
|
62 |
+
|
63 |
+
# Add VibeVoice-specific special tokens
|
64 |
+
self._add_vibevoice_special_tokens()
|
65 |
+
|
66 |
+
def _add_vibevoice_special_tokens(self):
|
67 |
+
"""Add VibeVoice-specific special tokens."""
|
68 |
+
special_tokens = {
|
69 |
+
"additional_special_tokens": [
|
70 |
+
"<|vision_start|>", # Speech start (reusing vision tokens)
|
71 |
+
"<|vision_end|>", # Speech end
|
72 |
+
"<|vision_pad|>", # Speech diffusion pad
|
73 |
+
]
|
74 |
+
}
|
75 |
+
num_added = self.add_special_tokens(special_tokens)
|
76 |
+
|
77 |
+
# Cache special token IDs
|
78 |
+
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
79 |
+
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
80 |
+
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
81 |
+
|
82 |
+
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
83 |
+
|
84 |
+
return num_added
|
85 |
+
|
86 |
+
@property
|
87 |
+
def eos_id(self) -> int:
|
88 |
+
"""Id of the end of sequence token."""
|
89 |
+
return self._eos_id
|
90 |
+
|
91 |
+
@property
|
92 |
+
def speech_start_id(self) -> int:
|
93 |
+
"""Id of the speech start token."""
|
94 |
+
return self._speech_start_id
|
95 |
+
|
96 |
+
@property
|
97 |
+
def speech_end_id(self) -> int:
|
98 |
+
"""Id of the speech end token."""
|
99 |
+
return self._speech_end_id
|
100 |
+
|
101 |
+
@property
|
102 |
+
def speech_diffusion_id(self) -> int:
|
103 |
+
"""Id of the speech diffusion token."""
|
104 |
+
return self._speech_diffusion_id
|
105 |
+
|
106 |
+
@property
|
107 |
+
def pad_id(self) -> int:
|
108 |
+
"""Id used for padding (returns -100 for loss masking)."""
|
109 |
+
return -100
|
110 |
+
|
111 |
+
|
112 |
+
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
|
113 |
+
"""
|
114 |
+
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
|
115 |
+
Based on the Qwen2 tokenizer with additional special tokens for speech.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
vocab_file (`str`, *optional*):
|
119 |
+
Path to the vocabulary file.
|
120 |
+
merges_file (`str`, *optional*):
|
121 |
+
Path to the merges file.
|
122 |
+
tokenizer_file (`str`, *optional*):
|
123 |
+
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
|
124 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
125 |
+
The unknown token.
|
126 |
+
bos_token (`str`, *optional*):
|
127 |
+
The beginning of sequence token. Not used for vibevoice.
|
128 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
129 |
+
The end of sequence token.
|
130 |
+
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
131 |
+
The token used for padding.
|
132 |
+
"""
|
133 |
+
|
134 |
+
model_input_names = ["input_ids", "attention_mask"]
|
135 |
+
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
vocab_file=None,
|
139 |
+
merges_file=None,
|
140 |
+
tokenizer_file=None,
|
141 |
+
unk_token="<|endoftext|>",
|
142 |
+
bos_token=None,
|
143 |
+
eos_token="<|endoftext|>",
|
144 |
+
pad_token="<|endoftext|>",
|
145 |
+
add_prefix_space=False,
|
146 |
+
**kwargs,
|
147 |
+
):
|
148 |
+
super().__init__(
|
149 |
+
vocab_file=vocab_file,
|
150 |
+
merges_file=merges_file,
|
151 |
+
tokenizer_file=tokenizer_file,
|
152 |
+
unk_token=unk_token,
|
153 |
+
bos_token=bos_token,
|
154 |
+
eos_token=eos_token,
|
155 |
+
pad_token=pad_token,
|
156 |
+
add_prefix_space=add_prefix_space,
|
157 |
+
**kwargs,
|
158 |
+
)
|
159 |
+
|
160 |
+
# Add VibeVoice-specific special tokens
|
161 |
+
self._add_vibevoice_special_tokens()
|
162 |
+
|
163 |
+
def _add_vibevoice_special_tokens(self):
|
164 |
+
"""Add VibeVoice-specific special tokens."""
|
165 |
+
special_tokens = {
|
166 |
+
"additional_special_tokens": [
|
167 |
+
"<|vision_start|>", # Speech start (reusing vision tokens)
|
168 |
+
"<|vision_end|>", # Speech end
|
169 |
+
"<|vision_pad|>", # Speech diffusion pad
|
170 |
+
]
|
171 |
+
}
|
172 |
+
num_added = self.add_special_tokens(special_tokens)
|
173 |
+
|
174 |
+
# Cache special token IDs
|
175 |
+
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
176 |
+
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
177 |
+
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
178 |
+
|
179 |
+
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
180 |
+
self._eos_id = self.eos_token_id # qwen2 / qwen3
|
181 |
+
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
|
182 |
+
|
183 |
+
return num_added
|
184 |
+
|
185 |
+
@property
|
186 |
+
def eos_id(self) -> int:
|
187 |
+
"""Id of the end of sequence token."""
|
188 |
+
return self._eos_id
|
189 |
+
|
190 |
+
@property
|
191 |
+
def speech_start_id(self) -> int:
|
192 |
+
"""Id of the speech start token."""
|
193 |
+
return self._speech_start_id
|
194 |
+
|
195 |
+
@property
|
196 |
+
def speech_end_id(self) -> int:
|
197 |
+
"""Id of the speech end token."""
|
198 |
+
return self._speech_end_id
|
199 |
+
|
200 |
+
@property
|
201 |
+
def speech_diffusion_id(self) -> int:
|
202 |
+
"""Id of the speech diffusion token."""
|
203 |
+
return self._speech_diffusion_id
|
204 |
+
|
205 |
+
@property
|
206 |
+
def pad_id(self) -> int:
|
207 |
+
"""Id used for padding (returns -100 for loss masking)."""
|
208 |
+
return self._pad_id
|
209 |
+
|
210 |
+
|
211 |
+
__all__ = [
|
212 |
+
"VibeVoiceTextTokenizer",
|
213 |
+
"VibeVoiceTextTokenizerFast",
|
214 |
+
]
|
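A short sketch of how the fast tokenizer's speech-token ids are meant to be consumed; the import path and the `from_pretrained` path are placeholders.

from modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast  # assumed import path

tokenizer = VibeVoiceTextTokenizerFast.from_pretrained("path/to/tokenizer")  # placeholder path

print(tokenizer.speech_start_id)      # id of <|vision_start|>
print(tokenizer.speech_end_id)        # id of <|vision_end|>
print(tokenizer.speech_diffusion_id)  # id of <|vision_pad|>

# These ids (plus eos/bos) are exactly the valid_tokens that the generation loop
# passes to VibeVoiceTokenConstraintProcessor earlier in this upload.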
modular/modular_vibevoice_tokenizer.py
ADDED
@@ -0,0 +1,1195 @@
1 |
+
import math
|
2 |
+
import typing as tp
|
3 |
+
from functools import partial
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from typing import Dict, List, Optional, Tuple, Union
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from transformers.models.auto import AutoModel
|
14 |
+
|
15 |
+
from transformers.configuration_utils import PretrainedConfig
|
16 |
+
from transformers.utils import logging
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.activations import ACT2FN
|
19 |
+
|
20 |
+
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
import os
|
25 |
+
# Try to import APEX FusedRMSNorm
|
26 |
+
try:
|
27 |
+
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
28 |
+
APEX_AVAILABLE = True
|
29 |
+
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
30 |
+
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
|
31 |
+
APEX_AVAILABLE = False
|
32 |
+
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
33 |
+
except ImportError:
|
34 |
+
APEX_AVAILABLE = False
|
35 |
+
logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
36 |
+
# APEX_AVAILABLE=False
|
37 |
+
|
38 |
+
# Normalization modules
|
39 |
+
class ConvLayerNorm(nn.LayerNorm):
|
40 |
+
"""
|
41 |
+
Convolution-friendly LayerNorm that moves channels to last dimensions
|
42 |
+
before running the normalization and moves them back to original position right after.
|
43 |
+
"""
|
44 |
+
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
|
45 |
+
super().__init__(normalized_shape, **kwargs)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
x = x.transpose(1, 2) # b ... t -> b t ...
|
49 |
+
x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
|
50 |
+
x = x.transpose(1, 2) # b t ... -> b ... t
|
51 |
+
return x
|
52 |
+
|
53 |
+
class RMSNorm(nn.Module):
|
54 |
+
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
|
55 |
+
super().__init__()
|
56 |
+
self.dim = dim
|
57 |
+
self.eps = eps
|
58 |
+
self.elementwise_affine = elementwise_affine
|
59 |
+
if self.elementwise_affine:
|
60 |
+
weight_shape = (dim,) if weight_shape is None else weight_shape
|
61 |
+
self.weight = nn.Parameter(torch.ones(weight_shape))
|
62 |
+
else:
|
63 |
+
self.register_parameter('weight', None)
|
64 |
+
|
65 |
+
def _norm(self, x):
|
66 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
output = self._norm(x.float()).type_as(x)
|
70 |
+
if self.weight is not None:
|
71 |
+
output = output * self.weight
|
72 |
+
return output
|
73 |
+
|
74 |
+
def extra_repr(self) -> str:
|
75 |
+
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
76 |
+
|
77 |
+
class ConvRMSNorm(RMSNorm):
|
78 |
+
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
|
79 |
+
super().__init__(dim, eps, elementwise_affine, weight_shape)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
x = x.transpose(1, 2) # b ... t -> b t ...
|
83 |
+
if (not APEX_AVAILABLE) or (not self.elementwise_affine):
|
84 |
+
# Fallback to native implementation
|
85 |
+
output = self._norm(x.float()).type_as(x)
|
86 |
+
if self.weight is not None:
|
87 |
+
output = output * self.weight
|
88 |
+
else:
|
89 |
+
output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
|
90 |
+
output = output.transpose(1, 2) # b t ... -> b ... t
|
91 |
+
return output
|
92 |
+
|
93 |
+
# Convolutional layers and utilities
|
94 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
95 |
+
'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
96 |
+
|
97 |
+
|
98 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
|
99 |
+
assert norm in CONV_NORMALIZATIONS
|
100 |
+
if norm == 'weight_norm':
|
101 |
+
return nn.utils.weight_norm(module)
|
102 |
+
elif norm == 'spectral_norm':
|
103 |
+
return nn.utils.spectral_norm(module)
|
104 |
+
else:
|
105 |
+
# We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
|
106 |
+
# doesn't need reparametrization.
|
107 |
+
return module
|
108 |
+
|
109 |
+
|
110 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
|
111 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
112 |
+
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
113 |
+
"""
|
114 |
+
assert norm in CONV_NORMALIZATIONS
|
115 |
+
if norm == 'layer_norm':
|
116 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
117 |
+
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
118 |
+
elif norm == 'time_group_norm':
|
119 |
+
if causal:
|
120 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
121 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
122 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
123 |
+
else:
|
124 |
+
return nn.Identity()
|
125 |
+
|
126 |
+
|
127 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
128 |
+
padding_total: int = 0) -> int:
|
129 |
+
"""Calculate extra padding needed for convolution to have the same output length"""
|
130 |
+
length = x.shape[-1]
|
131 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
132 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
133 |
+
return ideal_length - length
|
134 |
+
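A worked example of the helper above may make the arithmetic concrete (the numbers are illustrative only):

# length=11, kernel_size=4, stride=3, padding_total=0
#   n_frames      = (11 - 4 + 0) / 3 + 1 = 3.33 -> ceil() = 4 frames
#   ideal_length  = (4 - 1) * 3 + (4 - 0)  = 13
#   extra padding = 13 - 11 = 2 samples on the right
import torch
x = torch.zeros(1, 1, 11)
assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=3, padding_total=0) == 2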
|
135 |
+
|
136 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
|
137 |
+
"""Pad 1D input with handling for small inputs in reflect mode"""
|
138 |
+
length = x.shape[-1]
|
139 |
+
padding_left, padding_right = paddings
|
140 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
141 |
+
if mode == 'reflect':
|
142 |
+
max_pad = max(padding_left, padding_right)
|
143 |
+
extra_pad = 0
|
144 |
+
if length <= max_pad:
|
145 |
+
extra_pad = max_pad - length + 1
|
146 |
+
x = F.pad(x, (0, extra_pad))
|
147 |
+
padded = F.pad(x, paddings, mode, value)
|
148 |
+
end = padded.shape[-1] - extra_pad
|
149 |
+
return padded[..., :end]
|
150 |
+
else:
|
151 |
+
return F.pad(x, paddings, mode, value)
|
152 |
+
|
153 |
+
|
154 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
155 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
156 |
+
padding_left, padding_right = paddings
|
157 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
158 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
159 |
+
end = x.shape[-1] - padding_right
|
160 |
+
return x[..., padding_left: end]
|
161 |
+
|
162 |
+
|
163 |
+
class NormConv1d(nn.Module):
|
164 |
+
"""Wrapper around Conv1d and normalization applied to this conv"""
|
165 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
166 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
167 |
+
super().__init__()
|
168 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
169 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
170 |
+
self.norm_type = norm
|
171 |
+
|
172 |
+
def forward(self, x):
|
173 |
+
x = self.conv(x)
|
174 |
+
x = self.norm(x)
|
175 |
+
return x
|
176 |
+
|
177 |
+
|
178 |
+
class NormConvTranspose1d(nn.Module):
|
179 |
+
"""Wrapper around ConvTranspose1d and normalization applied to this conv"""
|
180 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
181 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
182 |
+
super().__init__()
|
183 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
184 |
+
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
185 |
+
self.norm_type = norm
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
x = self.convtr(x)
|
189 |
+
x = self.norm(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
class VibeVoiceTokenizerStreamingCache:
|
194 |
+
"""Cache for streaming convolution, similar to KV cache in attention"""
|
195 |
+
def __init__(self):
|
196 |
+
self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
|
197 |
+
|
198 |
+
def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
|
199 |
+
"""Get cached states for given layer and sample indices"""
|
200 |
+
states = []
|
201 |
+
max_length = 0
|
202 |
+
|
203 |
+
# First pass: collect states and find max length
|
204 |
+
for idx in sample_indices.tolist():
|
205 |
+
key = (layer_id, idx)
|
206 |
+
if key not in self.cache:
|
207 |
+
return None # If any sample is missing, return None
|
208 |
+
state = self.cache[key]
|
209 |
+
states.append(state)
|
210 |
+
max_length = max(max_length, state.shape[-1])
|
211 |
+
|
212 |
+
# Second pass: pad states to max length if needed
|
213 |
+
if len(states) > 0 and states[0].dim() >= 2:
|
214 |
+
padded_states = []
|
215 |
+
for state in states:
|
216 |
+
if state.shape[-1] < max_length:
|
217 |
+
# Pad on the time dimension (last dimension)
|
218 |
+
pad_size = max_length - state.shape[-1]
|
219 |
+
# Pad with zeros on the LEFT to align the most recent samples
|
220 |
+
padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
|
221 |
+
padded_states.append(padded_state)
|
222 |
+
else:
|
223 |
+
padded_states.append(state)
|
224 |
+
return torch.stack(padded_states, dim=0)
|
225 |
+
else:
|
226 |
+
return torch.stack(states, dim=0)
|
227 |
+
|
228 |
+
def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
|
229 |
+
"""Set cached states for given layer and sample indices"""
|
230 |
+
for i, idx in enumerate(sample_indices.tolist()):
|
231 |
+
key = (layer_id, idx)
|
232 |
+
self.cache[key] = states[i].detach()
|
233 |
+
|
234 |
+
def set_to_zero(self, sample_indices: torch.Tensor):
|
235 |
+
"""Set all cached states to zero for given sample indices"""
|
236 |
+
for key in list(self.cache.keys()):
|
237 |
+
layer_id, sample_idx = key
|
238 |
+
if sample_idx in sample_indices.tolist():
|
239 |
+
# Create zero tensor with same shape and dtype as cached tensor
|
240 |
+
cached_tensor = self.cache[key]
|
241 |
+
self.cache[key] = torch.zeros_like(cached_tensor)
|
242 |
+
|
243 |
+
def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
|
244 |
+
"""Clear cache for specific layer/samples or everything"""
|
245 |
+
if layer_id is None and sample_indices is None:
|
246 |
+
self.cache.clear()
|
247 |
+
elif layer_id is not None and sample_indices is None:
|
248 |
+
# Clear all samples for a specific layer
|
249 |
+
keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
|
250 |
+
for k in keys_to_remove:
|
251 |
+
del self.cache[k]
|
252 |
+
elif layer_id is not None and sample_indices is not None:
|
253 |
+
# Clear specific samples for a specific layer
|
254 |
+
for idx in sample_indices.tolist():
|
255 |
+
key = (layer_id, idx)
|
256 |
+
self.cache.pop(key, None)
|
257 |
+
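A toy illustration of the cache semantics above (the layer id and tensor shapes are arbitrary): states are keyed by (layer_id, sample_idx), and get() returns None until every requested sample has an entry.

import torch

cache = VibeVoiceTokenizerStreamingCache()
indices = torch.tensor([0, 1])

assert cache.get("layer_a", indices) is None      # nothing cached yet

states = torch.randn(2, 8, 6)                     # per-sample context, (B, C, T_ctx)
cache.set("layer_a", indices, states)
assert cache.get("layer_a", indices).shape == (2, 8, 6)

cache.clear(layer_id="layer_a")                   # drop this layer's entries
assert cache.get("layer_a", indices) is None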
|
258 |
+
class SConv1d(nn.Module):
|
259 |
+
"""Conv1d with built-in handling of asymmetric or causal padding and normalization."""
|
260 |
+
def __init__(self, in_channels: int, out_channels: int,
|
261 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
262 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
263 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
264 |
+
pad_mode: str = 'reflect'):
|
265 |
+
super().__init__()
|
266 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
267 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
268 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
269 |
+
self.causal = causal
|
270 |
+
self.pad_mode = pad_mode
|
271 |
+
|
272 |
+
# Store configuration
|
273 |
+
self.kernel_size = kernel_size
|
274 |
+
self.dilation = dilation
|
275 |
+
self.stride = stride
|
276 |
+
self.in_channels = in_channels
|
277 |
+
self.out_channels = out_channels
|
278 |
+
|
279 |
+
# For causal convolution, we need to maintain kernel_size - 1 samples as context
|
280 |
+
# TODO: verify which context_size definition is more suitable
|
281 |
+
# self.context_size = (kernel_size - 1) * dilation
|
282 |
+
self.context_size = (kernel_size - 1) * dilation - (stride - 1)
|
283 |
+
|
284 |
+
# For non-streaming mode, calculate padding
|
285 |
+
self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
286 |
+
|
287 |
+
# Create a unique layer ID for cache management
|
288 |
+
self._layer_id = None
|
289 |
+
|
290 |
+
@property
|
291 |
+
def layer_id(self):
|
292 |
+
if self._layer_id is None:
|
293 |
+
self._layer_id = f"sconv1d_{id(self)}"
|
294 |
+
return self._layer_id
|
295 |
+
|
296 |
+
def forward(self, x: torch.Tensor,
|
297 |
+
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
298 |
+
sample_indices: Optional[torch.Tensor] = None,
|
299 |
+
use_cache: bool = False,
|
300 |
+
debug: bool = False) -> torch.Tensor:
|
301 |
+
"""
|
302 |
+
Forward pass with optional streaming support via cache.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
x: Input tensor [batch_size, channels, time]
|
306 |
+
cache: VibeVoiceTokenizerStreamingCache object for maintaining states
|
307 |
+
sample_indices: Indices identifying each sample for cache management
|
308 |
+
use_cache: Whether to use cached states for streaming
|
309 |
+
debug: Whether to print debug information
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
Output tensor
|
313 |
+
"""
|
314 |
+
B, C, T = x.shape
|
315 |
+
|
316 |
+
# Non-streaming mode
|
317 |
+
if not use_cache or cache is None:
|
318 |
+
return self._forward_non_streaming(x, debug=debug)
|
319 |
+
|
320 |
+
# Streaming mode
|
321 |
+
assert self.causal, "Streaming mode is only supported for causal convolutions"
|
322 |
+
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
323 |
+
assert len(sample_indices) == B, "sample_indices must match batch size"
|
324 |
+
|
325 |
+
return self._forward_streaming(x, cache, sample_indices, debug)
|
326 |
+
|
327 |
+
def _forward_streaming(self, x: torch.Tensor,
|
328 |
+
cache: VibeVoiceTokenizerStreamingCache,
|
329 |
+
sample_indices: torch.Tensor,
|
330 |
+
debug: bool = False) -> torch.Tensor:
|
331 |
+
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
332 |
+
B, C, T = x.shape
|
333 |
+
|
334 |
+
# Cache operations (not compiled)
|
335 |
+
cached_states = cache.get(self.layer_id, sample_indices)
|
336 |
+
|
337 |
+
if cached_states is None:
|
338 |
+
# First chunk - initialize with zeros for context
|
339 |
+
if self.context_size > 0:
|
340 |
+
cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
|
341 |
+
if debug:
|
342 |
+
print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
|
343 |
+
else:
|
344 |
+
cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
|
345 |
+
if debug:
|
346 |
+
print(f"[DEBUG] No context needed (kernel_size=stride)")
|
347 |
+
|
348 |
+
# Concatenate cached states with input
|
349 |
+
if cached_states.shape[2] > 0:
|
350 |
+
input_with_context = torch.cat([cached_states, x], dim=2)
|
351 |
+
else:
|
352 |
+
input_with_context = x
|
353 |
+
|
354 |
+
if debug:
|
355 |
+
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
|
356 |
+
|
357 |
+
# Apply convolution directly - no extra padding in streaming mode
|
358 |
+
# The conv layer will handle its own padding internally
|
359 |
+
output = self.conv(input_with_context)
|
360 |
+
|
361 |
+
if debug:
|
362 |
+
print(f"[DEBUG] Output shape: {output.shape}")
|
363 |
+
|
364 |
+
# Update cache for next chunk
|
365 |
+
if self.context_size > 0:
|
366 |
+
# Calculate how many samples to keep
|
367 |
+
total_input_length = input_with_context.shape[2]
|
368 |
+
|
369 |
+
# Keep the last context_size samples
|
370 |
+
if total_input_length >= self.context_size:
|
371 |
+
new_cache_start = total_input_length - self.context_size
|
372 |
+
new_cache = input_with_context[:, :, new_cache_start:]
|
373 |
+
else:
|
374 |
+
# If we have less than context_size samples, keep everything
|
375 |
+
new_cache = input_with_context
|
376 |
+
|
377 |
+
if debug:
|
378 |
+
print(f"[DEBUG] New cache shape: {new_cache.shape}")
|
379 |
+
|
380 |
+
cache.set(self.layer_id, sample_indices, new_cache)
|
381 |
+
|
382 |
+
return output
|
383 |
+
|
384 |
+
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
|
385 |
+
"""Standard forward pass without streaming"""
|
386 |
+
B, C, T = x.shape
|
387 |
+
kernel_size = self.kernel_size
|
388 |
+
stride = self.stride
|
389 |
+
dilation = self.dilation
|
390 |
+
padding_total = self.padding_total
|
391 |
+
|
392 |
+
# Compute extra padding for stride alignment
|
393 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
394 |
+
|
395 |
+
if debug:
|
396 |
+
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
|
397 |
+
|
398 |
+
if self.causal:
|
399 |
+
# Left padding for causal
|
400 |
+
if self.pad_mode == 'constant':
|
401 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
|
402 |
+
else:
|
403 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
404 |
+
else:
|
405 |
+
# Symmetric padding for non-causal
|
406 |
+
padding_right = padding_total // 2
|
407 |
+
padding_left = padding_total - padding_right
|
408 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
409 |
+
|
410 |
+
if debug:
|
411 |
+
print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
|
412 |
+
|
413 |
+
output = self.conv(x)
|
414 |
+
|
415 |
+
if debug:
|
416 |
+
print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
|
417 |
+
|
418 |
+
return output
|
419 |
+
|
420 |
+
|
421 |
+
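A minimal streaming sketch for SConv1d (channel counts, chunk sizes and the padding mode are illustrative assumptions; exact chunk alignment depends on kernel size and stride):

import torch

conv = SConv1d(in_channels=1, out_channels=4, kernel_size=4, stride=2,
               causal=True, pad_mode='constant')
cache = VibeVoiceTokenizerStreamingCache()
indices = torch.tensor([0])

x = torch.randn(1, 1, 64)

y_full = conv(x)                                              # offline pass

y1 = conv(x[:, :, :32], cache=cache, sample_indices=indices, use_cache=True)
y2 = conv(x[:, :, 32:], cache=cache, sample_indices=indices, use_cache=True)
y_stream = torch.cat([y1, y2], dim=-1)                        # chunked pass

print(y_full.shape, y_stream.shape)   # the chunked output should line up with the offline one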
class SConvTranspose1d(nn.Module):
|
422 |
+
"""ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
|
423 |
+
def __init__(self, in_channels: int, out_channels: int,
|
424 |
+
kernel_size: int, stride: int = 1, causal: bool = False,
|
425 |
+
norm: str = 'none', trim_right_ratio: float = 1.,
|
426 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
|
427 |
+
super().__init__()
|
428 |
+
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
429 |
+
causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
|
430 |
+
self.causal = causal
|
431 |
+
self.trim_right_ratio = trim_right_ratio
|
432 |
+
assert self.causal or self.trim_right_ratio == 1., \
|
433 |
+
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
434 |
+
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
435 |
+
|
436 |
+
# Store configuration
|
437 |
+
self.kernel_size = kernel_size
|
438 |
+
self.stride = stride
|
439 |
+
self.in_channels = in_channels
|
440 |
+
self.out_channels = out_channels
|
441 |
+
|
442 |
+
# For transposed convolution, padding calculation is different
|
443 |
+
self.padding_total = kernel_size - stride
|
444 |
+
|
445 |
+
# For streaming, we need to keep track of input history
|
446 |
+
# Transposed conv needs to see multiple input samples to produce correct output
|
447 |
+
self.context_size = kernel_size - 1
|
448 |
+
|
449 |
+
# Create a unique layer ID for cache management
|
450 |
+
self._layer_id = None
|
451 |
+
|
452 |
+
@property
|
453 |
+
def layer_id(self):
|
454 |
+
if self._layer_id is None:
|
455 |
+
self._layer_id = f"sconvtr1d_{id(self)}"
|
456 |
+
return self._layer_id
|
457 |
+
|
458 |
+
def forward(self, x: torch.Tensor,
|
459 |
+
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
460 |
+
sample_indices: Optional[torch.Tensor] = None,
|
461 |
+
use_cache: bool = False,
|
462 |
+
debug: bool = False) -> torch.Tensor:
|
463 |
+
"""
|
464 |
+
Forward pass with optional streaming support via cache.
|
465 |
+
"""
|
466 |
+
B, C, T = x.shape
|
467 |
+
|
468 |
+
# Non-streaming mode
|
469 |
+
if not use_cache or cache is None:
|
470 |
+
return self._forward_non_streaming(x, debug=debug)
|
471 |
+
|
472 |
+
# Streaming mode
|
473 |
+
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
474 |
+
assert len(sample_indices) == B, "sample_indices must match batch size"
|
475 |
+
|
476 |
+
return self._forward_streaming(x, cache, sample_indices, debug)
|
477 |
+
|
478 |
+
def _forward_streaming(self, x: torch.Tensor,
|
479 |
+
cache: VibeVoiceTokenizerStreamingCache,
|
480 |
+
sample_indices: torch.Tensor,
|
481 |
+
debug: bool = False) -> torch.Tensor:
|
482 |
+
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
483 |
+
B, C, T = x.shape
|
484 |
+
|
485 |
+
# Cache operations (not compiled)
|
486 |
+
cached_input = cache.get(self.layer_id, sample_indices)
|
487 |
+
|
488 |
+
if cached_input is None:
|
489 |
+
# First chunk - no history yet
|
490 |
+
cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
|
491 |
+
if debug:
|
492 |
+
print(f"[DEBUG] Initialized empty cache for transposed conv")
|
493 |
+
|
494 |
+
# Concatenate cached input with new input
|
495 |
+
full_input = torch.cat([cached_input, x], dim=2)
|
496 |
+
|
497 |
+
if debug:
|
498 |
+
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
|
499 |
+
|
500 |
+
# First chunk or debug mode - use uncompiled version
|
501 |
+
full_output = self.convtr(full_input)
|
502 |
+
|
503 |
+
if debug:
|
504 |
+
print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
|
505 |
+
|
506 |
+
# Calculate padding to remove
|
507 |
+
if self.causal:
|
508 |
+
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
509 |
+
padding_left = self.padding_total - padding_right
|
510 |
+
else:
|
511 |
+
padding_right = self.padding_total // 2
|
512 |
+
padding_left = self.padding_total - padding_right
|
513 |
+
|
514 |
+
# Remove padding
|
515 |
+
if padding_left + padding_right > 0:
|
516 |
+
full_output = unpad1d(full_output, (padding_left, padding_right))
|
517 |
+
|
518 |
+
if debug:
|
519 |
+
print(f"[DEBUG] After unpadding: {full_output.shape}")
|
520 |
+
|
521 |
+
# Determine which part of the output corresponds to the new input
|
522 |
+
if cached_input.shape[2] == 0:
|
523 |
+
# First chunk - return all output
|
524 |
+
output = full_output
|
525 |
+
else:
|
526 |
+
# Subsequent chunks - return only the new output
|
527 |
+
expected_new_output = T * self.stride
|
528 |
+
|
529 |
+
# Take the last expected_new_output samples
|
530 |
+
if full_output.shape[2] >= expected_new_output:
|
531 |
+
output = full_output[:, :, -expected_new_output:]
|
532 |
+
else:
|
533 |
+
output = full_output
|
534 |
+
|
535 |
+
if debug:
|
536 |
+
print(f"[DEBUG] Final streaming output shape: {output.shape}")
|
537 |
+
|
538 |
+
# Update cache
|
539 |
+
if full_input.shape[2] > self.context_size:
|
540 |
+
new_cache = full_input[:, :, -self.context_size:]
|
541 |
+
else:
|
542 |
+
new_cache = full_input
|
543 |
+
|
544 |
+
if debug:
|
545 |
+
print(f"[DEBUG] New cache shape: {new_cache.shape}")
|
546 |
+
|
547 |
+
cache.set(self.layer_id, sample_indices, new_cache)
|
548 |
+
|
549 |
+
return output
|
550 |
+
|
551 |
+
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
|
552 |
+
"""Standard forward pass without streaming"""
|
553 |
+
if debug:
|
554 |
+
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
|
555 |
+
|
556 |
+
# Apply transposed convolution
|
557 |
+
y = self.convtr(x)
|
558 |
+
|
559 |
+
if debug:
|
560 |
+
print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
|
561 |
+
|
562 |
+
# Calculate and remove padding
|
563 |
+
if self.causal:
|
564 |
+
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
565 |
+
padding_left = self.padding_total - padding_right
|
566 |
+
else:
|
567 |
+
padding_right = self.padding_total // 2
|
568 |
+
padding_left = self.padding_total - padding_right
|
569 |
+
|
570 |
+
if padding_left + padding_right > 0:
|
571 |
+
y = unpad1d(y, (padding_left, padding_right))
|
572 |
+
|
573 |
+
if debug:
|
574 |
+
print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
|
575 |
+
|
576 |
+
return y
|
577 |
+
|
578 |
+
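For reference, the non-streaming output length of SConvTranspose1d follows the usual transposed-convolution arithmetic; a quick sketch with assumed sizes:

# With padding_total = kernel_size - stride (as set above):
#   raw transposed-conv length : (T_in - 1) * stride + kernel_size
#   after trimming the padding : (T_in - 1) * stride + kernel_size - padding_total = T_in * stride
import torch
up = SConvTranspose1d(in_channels=4, out_channels=2, kernel_size=10, stride=5, causal=True)
y = up(torch.randn(1, 4, 20))
assert y.shape[-1] == 20 * 5   # upsampled by exactly the stride factor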
# FFN
|
579 |
+
class FFN(nn.Module):
|
580 |
+
def __init__(
|
581 |
+
self,
|
582 |
+
embed_dim,
|
583 |
+
ffn_dim,
|
584 |
+
bias=False,
|
585 |
+
):
|
586 |
+
super().__init__()
|
587 |
+
self.embed_dim = embed_dim
|
588 |
+
self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
|
589 |
+
self.gelu = ACT2FN["gelu"]
|
590 |
+
self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
|
591 |
+
|
592 |
+
def forward(self, x):
|
593 |
+
x = self.linear1(x)
|
594 |
+
x = self.gelu(x)
|
595 |
+
x = self.linear2(x)
|
596 |
+
return x
|
597 |
+
|
598 |
+
|
599 |
+
class Convlayer(nn.Module):
|
600 |
+
def __init__(
|
601 |
+
self,
|
602 |
+
in_channels,
|
603 |
+
out_channels,
|
604 |
+
kernel_size,
|
605 |
+
stride=1,
|
606 |
+
dilation=1,
|
607 |
+
groups=1,
|
608 |
+
bias=True,
|
609 |
+
pad_mode='zeros',
|
610 |
+
norm='weight_norm',
|
611 |
+
causal=True,
|
612 |
+
):
|
613 |
+
super().__init__()
|
614 |
+
self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
|
615 |
+
groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
|
616 |
+
|
617 |
+
def forward(self, x):
|
618 |
+
return self.conv(x)
|
619 |
+
|
620 |
+
class Block1D(nn.Module):
|
621 |
+
def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
|
622 |
+
layer_scale_init_value=1e-6, **kwargs):
|
623 |
+
super().__init__()
|
624 |
+
|
625 |
+
if kwargs.get('layernorm', 'LN') == 'LN':
|
626 |
+
self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
|
627 |
+
self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
|
628 |
+
elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
|
629 |
+
self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
|
630 |
+
self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
|
631 |
+
|
632 |
+
if mixer_layer == 'conv':
|
633 |
+
self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
|
634 |
+
kernel_size=kernel_size,
|
635 |
+
pad_mode=kwargs.get('pad_mode', 'reflect'),
|
636 |
+
norm=kwargs.get('norm', 'none'),
|
637 |
+
causal=kwargs.get('causal', True),
|
638 |
+
bias=kwargs.get('bias', True),
|
639 |
+
)
|
640 |
+
elif mixer_layer == 'depthwise_conv':
|
641 |
+
self.mixer = Convlayer(dim, dim, groups=dim,
|
642 |
+
kernel_size=kernel_size,
|
643 |
+
pad_mode=kwargs.get('pad_mode', 'reflect'),
|
644 |
+
norm=kwargs.get('norm', 'none'),
|
645 |
+
causal=kwargs.get('causal', True),
|
646 |
+
bias=kwargs.get('bias', True),
|
647 |
+
)
|
648 |
+
else:
|
649 |
+
raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
|
650 |
+
|
651 |
+
self.ffn = FFN(
|
652 |
+
dim,
|
653 |
+
kwargs.get('ffn_expansion', 4) * dim,
|
654 |
+
bias=kwargs.get('bias', False),
|
655 |
+
)
|
656 |
+
self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)  # NOTE: torch.nn provides no DropPath; a stochastic-depth implementation (e.g. timm's DropPath) is required if drop_path > 0
|
657 |
+
|
658 |
+
if layer_scale_init_value > 0:
|
659 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
660 |
+
self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
661 |
+
else:
|
662 |
+
self.gamma = None
|
663 |
+
self.ffn_gamma = None
|
664 |
+
|
665 |
+
def forward(self, x):
|
666 |
+
# mixer
|
667 |
+
residual = x
|
668 |
+
x = self.norm(x)
|
669 |
+
x = self.mixer(x)
|
670 |
+
if self.gamma is not None:
|
671 |
+
x = x * self.gamma.unsqueeze(-1)
|
672 |
+
x = residual + self.drop_path(x)
|
673 |
+
|
674 |
+
# ffn
|
675 |
+
residual = x
|
676 |
+
x = self.ffn_norm(x)
|
677 |
+
x = x.permute(0, 2, 1)
|
678 |
+
x = self.ffn(x)
|
679 |
+
x = x.permute(0, 2, 1)
|
680 |
+
if self.ffn_gamma is not None:
|
681 |
+
x = x * self.ffn_gamma.unsqueeze(-1)
|
682 |
+
x = residual + self.drop_path(x)
|
683 |
+
|
684 |
+
return x
|
685 |
+
|
686 |
+
|
687 |
+
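Block1D is a pre-norm residual block (mixer followed by FFN) operating on channel-first tensors; a small usage sketch with assumed hyperparameters:

import torch

block = Block1D(dim=32, kernel_size=7, mixer_layer='depthwise_conv',
                layernorm='RMSNorm', causal=True, pad_mode='constant',
                layer_scale_init_value=1e-6)
x = torch.randn(2, 32, 50)      # (B, C, T)
y = block(x)
assert y.shape == x.shape       # both residual branches preserve the shape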
class TokenizerEncoder(nn.Module):
|
688 |
+
"""
|
689 |
+
Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
|
690 |
+
|
691 |
+
Args:
|
692 |
+
config: Configuration object with model parameters
|
693 |
+
"""
|
694 |
+
def __init__(self, config):
|
695 |
+
super().__init__()
|
696 |
+
|
697 |
+
# Extract parameters from config
|
698 |
+
self.channels = config.channels
|
699 |
+
self.dimension = config.dimension
|
700 |
+
self.n_filters = config.n_filters
|
701 |
+
self.ratios = list(reversed(config.ratios))
|
702 |
+
self.depths = config.depths
|
703 |
+
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
|
704 |
+
self.hop_length = np.prod(self.ratios)
|
705 |
+
self.causal = config.causal
|
706 |
+
|
707 |
+
# Additional config parameters with defaults
|
708 |
+
kernel_size = getattr(config, "kernel_size", 7)
|
709 |
+
last_kernel_size = getattr(config, "last_kernel_size", 7)
|
710 |
+
norm = getattr(config, "norm", "none")
|
711 |
+
norm_params = getattr(config, "norm_params", {})
|
712 |
+
pad_mode = getattr(config, "pad_mode", "reflect")
|
713 |
+
bias = getattr(config, "bias", True)
|
714 |
+
layernorm = getattr(config, "layernorm", "LN")
|
715 |
+
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
|
716 |
+
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
|
717 |
+
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
|
718 |
+
mixer_layer = getattr(config, "mixer_layer", "conv")
|
719 |
+
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
|
720 |
+
disable_last_norm = getattr(config, "disable_last_norm", False)
|
721 |
+
|
722 |
+
# determine the norm type based on layernorm
|
723 |
+
if layernorm == 'LN':
|
724 |
+
norm_type = ConvLayerNorm
|
725 |
+
elif layernorm == 'RMSNorm':
|
726 |
+
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
|
727 |
+
else:
|
728 |
+
raise ValueError(f"Unsupported norm type: {layernorm}")
|
729 |
+
|
730 |
+
# stem and intermediate downsampling conv layers
|
731 |
+
stem = nn.Sequential(
|
732 |
+
SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
|
733 |
+
)
|
734 |
+
|
735 |
+
self.downsample_layers = nn.ModuleList()
|
736 |
+
self.downsample_layers.append(stem)
|
737 |
+
for i in range(len(self.ratios)):
|
738 |
+
in_ch = self.n_filters * (2 ** i)
|
739 |
+
out_ch = self.n_filters * (2 ** (i + 1))
|
740 |
+
downsample_layer = nn.Sequential(
|
741 |
+
SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
742 |
+
)
|
743 |
+
self.downsample_layers.append(downsample_layer)
|
744 |
+
|
745 |
+
# configure the transformer blocks
|
746 |
+
layer_type = partial(
|
747 |
+
Block1D,
|
748 |
+
mixer_layer=mixer_layer,
|
749 |
+
layernorm=layernorm,
|
750 |
+
eps=layernorm_eps,
|
751 |
+
causal=self.causal,
|
752 |
+
pad_mode=pad_mode,
|
753 |
+
norm=norm,
|
754 |
+
bias=bias,
|
755 |
+
layer_scale_init_value=layer_scale_init_value,
|
756 |
+
)
|
757 |
+
|
758 |
+
self.stages = nn.ModuleList()
|
759 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
760 |
+
cur = 0
|
761 |
+
|
762 |
+
for i in range(len(self.depths)):
|
763 |
+
in_ch = self.n_filters * (2 ** i)
|
764 |
+
stage = nn.Sequential(
|
765 |
+
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
|
766 |
+
)
|
767 |
+
self.stages.append(stage)
|
768 |
+
cur += self.depths[i]
|
769 |
+
|
770 |
+
if not disable_last_norm:
|
771 |
+
self.norm = norm_type(in_ch, eps=layernorm_eps)
|
772 |
+
else:
|
773 |
+
self.norm = nn.Identity()
|
774 |
+
self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
775 |
+
|
776 |
+
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
|
777 |
+
for i in range(len(self.depths)):
|
778 |
+
# Apply downsampling
|
779 |
+
for layer in self.downsample_layers[i]:
|
780 |
+
if isinstance(layer, SConv1d):
|
781 |
+
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
782 |
+
else:
|
783 |
+
x = layer(x)
|
784 |
+
|
785 |
+
# Apply stage (Block1D contains Convlayer which contains SConv1d)
|
786 |
+
for block in self.stages[i]:
|
787 |
+
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
|
788 |
+
# Block1D forward with cache support
|
789 |
+
residual = x
|
790 |
+
x = block.norm(x)
|
791 |
+
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
792 |
+
if block.gamma is not None:
|
793 |
+
x = x * block.gamma.unsqueeze(-1)
|
794 |
+
x = residual + x
|
795 |
+
|
796 |
+
# FFN part
|
797 |
+
residual = x
|
798 |
+
x = block.ffn_norm(x)
|
799 |
+
x = x.permute(0, 2, 1)
|
800 |
+
x = block.ffn(x)
|
801 |
+
x = x.permute(0, 2, 1)
|
802 |
+
if block.ffn_gamma is not None:
|
803 |
+
x = x * block.ffn_gamma.unsqueeze(-1)
|
804 |
+
x = residual + x
|
805 |
+
else:
|
806 |
+
x = block(x)
|
807 |
+
|
808 |
+
return self.norm(x)
|
809 |
+
|
810 |
+
def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
|
811 |
+
x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
812 |
+
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
813 |
+
return x
|
814 |
+
|
815 |
+
|
816 |
+
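The encoder's overall downsampling factor is the product of the stage ratios (stored as hop_length above). A hedged example with made-up ratios, not necessarily the shipped configuration:

import numpy as np

ratios = [8, 5, 4, 2]                      # illustrative values only
hop_length = int(np.prod(ratios))          # 320 input samples per latent frame
# e.g. at an assumed 16 kHz sampling rate: 16000 / 320 = 50 latent frames per second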
class TokenizerDecoder(nn.Module):
|
817 |
+
"""
|
818 |
+
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
|
819 |
+
|
820 |
+
Args:
|
821 |
+
config: Configuration object with model parameters
|
822 |
+
"""
|
823 |
+
def __init__(self, config):
|
824 |
+
super().__init__()
|
825 |
+
|
826 |
+
# Extract parameters from config
|
827 |
+
self.dimension = config.dimension
|
828 |
+
self.channels = config.channels
|
829 |
+
self.n_filters = config.n_filters
|
830 |
+
self.ratios = config.ratios
|
831 |
+
|
832 |
+
# IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
|
833 |
+
self.depths = config.depths # Changed from list(reversed(config.depths))
|
834 |
+
|
835 |
+
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
|
836 |
+
self.hop_length = np.prod(self.ratios)
|
837 |
+
self.causal = config.causal
|
838 |
+
|
839 |
+
# Additional config parameters with defaults
|
840 |
+
kernel_size = getattr(config, "kernel_size", 7)
|
841 |
+
last_kernel_size = getattr(config, "last_kernel_size", 7)
|
842 |
+
norm = getattr(config, "norm", "none")
|
843 |
+
norm_params = getattr(config, "norm_params", {})
|
844 |
+
pad_mode = getattr(config, "pad_mode", "reflect")
|
845 |
+
bias = getattr(config, "bias", True)
|
846 |
+
layernorm = getattr(config, "layernorm", "LN")
|
847 |
+
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
|
848 |
+
trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
|
849 |
+
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
|
850 |
+
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
|
851 |
+
mixer_layer = getattr(config, "mixer_layer", "conv")
|
852 |
+
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
|
853 |
+
disable_last_norm = getattr(config, "disable_last_norm", False)
|
854 |
+
|
855 |
+
# determine the norm type based on layernorm
|
856 |
+
if layernorm == 'LN':
|
857 |
+
norm_type = ConvLayerNorm
|
858 |
+
elif layernorm == 'RMSNorm':
|
859 |
+
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
|
860 |
+
else:
|
861 |
+
raise ValueError(f"Unsupported norm type: {layernorm}")
|
862 |
+
|
863 |
+
# stem and upsampling layers
|
864 |
+
stem = nn.Sequential(
|
865 |
+
SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
|
866 |
+
norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
|
867 |
+
)
|
868 |
+
|
869 |
+
self.upsample_layers = nn.ModuleList()
|
870 |
+
self.upsample_layers.append(stem)
|
871 |
+
for i in range(len(self.ratios)):
|
872 |
+
in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
|
873 |
+
out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
|
874 |
+
upsample_layer = nn.Sequential(
|
875 |
+
SConvTranspose1d(in_ch, out_ch,
|
876 |
+
kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
|
877 |
+
norm=norm, norm_kwargs=norm_params, bias=bias,
|
878 |
+
causal=self.causal, trim_right_ratio=trim_right_ratio),
|
879 |
+
)
|
880 |
+
self.upsample_layers.append(upsample_layer)
|
881 |
+
|
882 |
+
# configure transformer blocks
|
883 |
+
layer_type = partial(
|
884 |
+
Block1D,
|
885 |
+
mixer_layer=mixer_layer,
|
886 |
+
layernorm=layernorm,
|
887 |
+
eps=layernorm_eps,
|
888 |
+
causal=self.causal,
|
889 |
+
pad_mode=pad_mode,
|
890 |
+
norm=norm,
|
891 |
+
bias=bias,
|
892 |
+
layer_scale_init_value=layer_scale_init_value,
|
893 |
+
)
|
894 |
+
|
895 |
+
self.stages = nn.ModuleList()
|
896 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
897 |
+
cur = 0
|
898 |
+
|
899 |
+
# Create stages in the same order as the original model
|
900 |
+
for i in range(len(self.depths)):
|
901 |
+
in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
|
902 |
+
stage = nn.Sequential(
|
903 |
+
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
|
904 |
+
)
|
905 |
+
self.stages.append(stage)
|
906 |
+
cur += self.depths[i]
|
907 |
+
|
908 |
+
if not disable_last_norm:
|
909 |
+
self.norm = norm_type(in_ch, eps=layernorm_eps)
|
910 |
+
else:
|
911 |
+
self.norm = nn.Identity()
|
912 |
+
self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
913 |
+
|
914 |
+
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
|
915 |
+
for i in range(len(self.depths)):
|
916 |
+
# Apply upsampling
|
917 |
+
for layer in self.upsample_layers[i]:
|
918 |
+
if isinstance(layer, (SConv1d, SConvTranspose1d)):
|
919 |
+
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
920 |
+
else:
|
921 |
+
x = layer(x)
|
922 |
+
|
923 |
+
# Apply stage (Block1D contains Convlayer which contains SConv1d)
|
924 |
+
for block in self.stages[i]:
|
925 |
+
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
|
926 |
+
# Block1D forward with cache support
|
927 |
+
residual = x
|
928 |
+
x = block.norm(x)
|
929 |
+
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
930 |
+
if block.gamma is not None:
|
931 |
+
x = x * block.gamma.unsqueeze(-1)
|
932 |
+
x = residual + x
|
933 |
+
|
934 |
+
# FFN part
|
935 |
+
residual = x
|
936 |
+
x = block.ffn_norm(x)
|
937 |
+
x = x.permute(0, 2, 1)
|
938 |
+
x = block.ffn(x)
|
939 |
+
x = x.permute(0, 2, 1)
|
940 |
+
if block.ffn_gamma is not None:
|
941 |
+
x = x * block.ffn_gamma.unsqueeze(-1)
|
942 |
+
x = residual + x
|
943 |
+
else:
|
944 |
+
x = block(x)
|
945 |
+
|
946 |
+
return self.norm(x)
|
947 |
+
|
948 |
+
def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
|
949 |
+
x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
950 |
+
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
951 |
+
return x
|
952 |
+
|
953 |
+
|
954 |
+
@dataclass
|
955 |
+
class VibeVoiceTokenizerEncoderOutput:
|
956 |
+
"""
|
957 |
+
Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
|
958 |
+
|
959 |
+
Args:
|
960 |
+
mean (`torch.FloatTensor`): The mean parameters of the distribution.
|
961 |
+
std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
|
962 |
+
"""
|
963 |
+
mean: torch.Tensor
|
964 |
+
std: Optional[Union[float, torch.Tensor]] = None
|
965 |
+
|
966 |
+
def sample(self, dist_type='fix'):
|
967 |
+
"""
|
968 |
+
Sample from the distribution.
|
969 |
+
|
970 |
+
Args:
|
971 |
+
dist_type (`str`): Sampling method, either 'fix' or 'gaussian'; any other value returns the mean without added noise.
|
972 |
+
|
973 |
+
Returns:
|
974 |
+
`torch.FloatTensor`: Sampled values.
|
975 |
+
`torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
|
976 |
+
"""
|
977 |
+
if dist_type == 'fix':
|
978 |
+
x = self.mean + self.std * torch.randn_like(self.mean)
|
979 |
+
return x, self.std
|
980 |
+
elif dist_type == 'gaussian':
|
981 |
+
batch_size = self.mean.size(0)
|
982 |
+
value = self.std / 0.8
|
983 |
+
std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
|
984 |
+
|
985 |
+
while std.dim() < self.mean.dim():
|
986 |
+
std = std.unsqueeze(-1)
|
987 |
+
|
988 |
+
x = self.mean + std * torch.randn_like(self.mean)
|
989 |
+
return x, std
|
990 |
+
else:
|
991 |
+
return self.mean, self.std
|
992 |
+
|
993 |
+
def kl(self):
|
994 |
+
"""Compute KL divergence between this distribution and a standard normal."""
|
995 |
+
target = torch.zeros_like(self.mean)
|
996 |
+
return F.mse_loss(self.mean, target, reduction='none')
|
997 |
+
|
998 |
+
def mode(self):
|
999 |
+
"""Return the distribution mode (which is the mean for Gaussian)."""
|
1000 |
+
return self.mean
|
1001 |
+
|
1002 |
+
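The 'fix' sampling path above is a reparameterised draw around the mean with a constant standard deviation; a tiny sketch:

import torch

mean = torch.zeros(2, 10, 64)                 # (B, frames, vae_dim)
out = VibeVoiceTokenizerEncoderOutput(mean=mean, std=0.5)

z, std = out.sample(dist_type='fix')          # z = mean + 0.5 * noise
assert z.shape == mean.shape and std == 0.5
assert torch.equal(out.mode(), mean)          # mode() is deterministic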
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
1003 |
+
"""VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
|
1004 |
+
|
1005 |
+
config_class = VibeVoiceAcousticTokenizerConfig
|
1006 |
+
base_model_prefix = "vibevoice_acoustic_tokenizer"
|
1007 |
+
_supports_flash_attn_2 = True
|
1008 |
+
_supports_sdpa = True
|
1009 |
+
_no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
|
1010 |
+
|
1011 |
+
def __init__(self, config):
|
1012 |
+
super().__init__(config)
|
1013 |
+
|
1014 |
+
self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
|
1015 |
+
self.std_dist_type = getattr(config, "std_dist_type", "fix")
|
1016 |
+
|
1017 |
+
# Parse encoder depths
|
1018 |
+
if isinstance(config.encoder_depths, str):
|
1019 |
+
encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
|
1020 |
+
else:
|
1021 |
+
encoder_depths = config.encoder_depths
|
1022 |
+
|
1023 |
+
# Parse decoder depths if provided
|
1024 |
+
if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
|
1025 |
+
decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
|
1026 |
+
else:
|
1027 |
+
# Default: use reversed encoder depths if decoder_depths is None
|
1028 |
+
decoder_depths = list(reversed(encoder_depths))
|
1029 |
+
|
1030 |
+
# Create encoder config
|
1031 |
+
encoder_config = copy.deepcopy(config)
|
1032 |
+
encoder_config.dimension = config.vae_dim
|
1033 |
+
encoder_config.n_filters = config.encoder_n_filters
|
1034 |
+
encoder_config.ratios = config.encoder_ratios
|
1035 |
+
encoder_config.depths = encoder_depths
|
1036 |
+
encoder_config.norm = config.conv_norm
|
1037 |
+
encoder_config.pad_mode = config.pad_mode
|
1038 |
+
encoder_config.bias = config.conv_bias
|
1039 |
+
encoder_config.layernorm_eps = config.layernorm_eps
|
1040 |
+
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
1041 |
+
encoder_config.mixer_layer = config.mixer_layer
|
1042 |
+
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
1043 |
+
encoder_config.disable_last_norm = config.disable_last_norm
|
1044 |
+
|
1045 |
+
# Create decoder config
|
1046 |
+
decoder_config = copy.deepcopy(config)
|
1047 |
+
decoder_config.dimension = config.vae_dim
|
1048 |
+
decoder_config.n_filters = config.decoder_n_filters
|
1049 |
+
decoder_config.ratios = config.decoder_ratios
|
1050 |
+
decoder_config.depths = decoder_depths
|
1051 |
+
decoder_config.norm = config.conv_norm
|
1052 |
+
decoder_config.pad_mode = config.pad_mode
|
1053 |
+
decoder_config.bias = config.conv_bias
|
1054 |
+
decoder_config.layernorm_eps = config.layernorm_eps
|
1055 |
+
decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
1056 |
+
decoder_config.mixer_layer = config.mixer_layer
|
1057 |
+
decoder_config.layer_scale_init_value = config.layer_scale_init_value
|
1058 |
+
decoder_config.disable_last_norm = config.disable_last_norm
|
1059 |
+
|
1060 |
+
# Initialize encoder and decoder
|
1061 |
+
self.encoder = TokenizerEncoder(encoder_config)
|
1062 |
+
self.decoder = TokenizerDecoder(decoder_config)
|
1063 |
+
|
1064 |
+
# Initialize weights
|
1065 |
+
self.apply(self._init_weights)
|
1066 |
+
|
1067 |
+
def _init_weights(self, module):
|
1068 |
+
"""Initialize weights for the model"""
|
1069 |
+
if isinstance(module, nn.Linear):
|
1070 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
1071 |
+
if module.bias is not None:
|
1072 |
+
nn.init.zeros_(module.bias)
|
1073 |
+
elif isinstance(module, nn.LayerNorm):
|
1074 |
+
nn.init.ones_(module.weight)
|
1075 |
+
nn.init.zeros_(module.bias)
|
1076 |
+
elif isinstance(module, nn.Conv1d):
|
1077 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
1078 |
+
if module.bias is not None:
|
1079 |
+
nn.init.zeros_(module.bias)
|
1080 |
+
|
1081 |
+
@torch.no_grad()
|
1082 |
+
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
1083 |
+
"""Convert audio to latent representations"""
|
1084 |
+
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1085 |
+
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
|
1086 |
+
|
1087 |
+
@torch.no_grad()
|
1088 |
+
def sampling(self, encoder_output, dist_type=None):
|
1089 |
+
"""Sample from the encoder output distribution"""
|
1090 |
+
dist_type = dist_type or self.std_dist_type
|
1091 |
+
|
1092 |
+
if dist_type == 'fix':
|
1093 |
+
return encoder_output.sample(dist_type='fix')
|
1094 |
+
elif dist_type == 'gaussian':
|
1095 |
+
return encoder_output.sample(dist_type='gaussian')
|
1096 |
+
else:
|
1097 |
+
raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
|
1098 |
+
|
1099 |
+
@torch.no_grad()
|
1100 |
+
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
|
1101 |
+
"""Convert latent representations back to audio"""
|
1102 |
+
if latents.shape[1] == self.config.vae_dim:
|
1103 |
+
pass
|
1104 |
+
else:
|
1105 |
+
latents = latents.permute(0, 2, 1)
|
1106 |
+
|
1107 |
+
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1108 |
+
return audio
|
1109 |
+
|
1110 |
+
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
1111 |
+
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
1112 |
+
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1113 |
+
sampled_latents, _ = self.sampling(encoder_output)
|
1114 |
+
reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1115 |
+
return reconstructed, sampled_latents
|
1116 |
+
|
1117 |
+
|
1118 |
+
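A hedged end-to-end sketch of the acoustic tokenizer (this assumes VibeVoiceAcousticTokenizerConfig can be instantiated with its defaults; the waveform length is arbitrary):

import torch

config = VibeVoiceAcousticTokenizerConfig()
model = VibeVoiceAcousticTokenizerModel(config).eval()

audio = torch.randn(1, 1, 32000)              # mono waveform chunk
enc_out = model.encode(audio)                 # VibeVoiceTokenizerEncoderOutput
latents, _ = model.sampling(enc_out)          # (B, frames, vae_dim)
recon = model.decode(latents)                 # back to a (B, 1, T') waveform
print(latents.shape, recon.shape)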
class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
|
1119 |
+
"""VibeVoice speech tokenizer model with only encoder for semantic tokens"""
|
1120 |
+
|
1121 |
+
config_class = VibeVoiceSemanticTokenizerConfig
|
1122 |
+
base_model_prefix = "vibevoice_semantic_tokenizer"
|
1123 |
+
_supports_flash_attn_2 = True
|
1124 |
+
_supports_sdpa = True
|
1125 |
+
_no_split_modules = ["TokenizerEncoder"]
|
1126 |
+
|
1127 |
+
def __init__(self, config):
|
1128 |
+
super().__init__(config)
|
1129 |
+
|
1130 |
+
# Parse encoder depths
|
1131 |
+
if isinstance(config.encoder_depths, str):
|
1132 |
+
encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
|
1133 |
+
else:
|
1134 |
+
encoder_depths = config.encoder_depths
|
1135 |
+
|
1136 |
+
# Create encoder config
|
1137 |
+
encoder_config = copy.deepcopy(config)
|
1138 |
+
encoder_config.dimension = config.vae_dim
|
1139 |
+
encoder_config.n_filters = config.encoder_n_filters
|
1140 |
+
encoder_config.ratios = config.encoder_ratios
|
1141 |
+
encoder_config.depths = encoder_depths
|
1142 |
+
encoder_config.norm = config.conv_norm
|
1143 |
+
encoder_config.pad_mode = config.pad_mode
|
1144 |
+
encoder_config.bias = config.conv_bias
|
1145 |
+
encoder_config.layernorm_eps = config.layernorm_eps
|
1146 |
+
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
1147 |
+
encoder_config.mixer_layer = config.mixer_layer
|
1148 |
+
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
1149 |
+
encoder_config.disable_last_norm = config.disable_last_norm
|
1150 |
+
|
1151 |
+
# Initialize encoder (the semantic tokenizer has no decoder)
|
1152 |
+
self.encoder = TokenizerEncoder(encoder_config)
|
1153 |
+
|
1154 |
+
# Initialize weights
|
1155 |
+
self.apply(self._init_weights)
|
1156 |
+
|
1157 |
+
def _init_weights(self, module):
|
1158 |
+
"""Initialize weights for the model"""
|
1159 |
+
if isinstance(module, nn.Linear):
|
1160 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
1161 |
+
if module.bias is not None:
|
1162 |
+
nn.init.zeros_(module.bias)
|
1163 |
+
elif isinstance(module, nn.LayerNorm):
|
1164 |
+
nn.init.ones_(module.weight)
|
1165 |
+
nn.init.zeros_(module.bias)
|
1166 |
+
elif isinstance(module, nn.Conv1d):
|
1167 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
1168 |
+
if module.bias is not None:
|
1169 |
+
nn.init.zeros_(module.bias)
|
1170 |
+
|
1171 |
+
@torch.no_grad()
|
1172 |
+
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
1173 |
+
"""Convert audio to latent representations"""
|
1174 |
+
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1175 |
+
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
|
1176 |
+
|
1177 |
+
@torch.no_grad()
|
1178 |
+
def sampling(self, encoder_output, dist_type=None):
|
1179 |
+
"""Sample from the encoder output distribution"""
|
1180 |
+
return encoder_output.sample(dist_type='none')
|
1181 |
+
|
1182 |
+
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
1183 |
+
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
1184 |
+
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
1185 |
+
sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
|
1186 |
+
return None, sampled_latents
|
1187 |
+
|
1188 |
+
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
|
1189 |
+
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
|
1190 |
+
|
1191 |
+
__all__ = [
|
1192 |
+
"VibeVoiceTokenizerStreamingCache",
|
1193 |
+
"VibeVoiceAcousticTokenizerModel",
|
1194 |
+
"VibeVoiceSemanticTokenizerModel",
|
1195 |
+
]
|
modular/streamer.py
ADDED
@@ -0,0 +1,264 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import asyncio
|
6 |
+
from queue import Queue
|
7 |
+
from typing import TYPE_CHECKING, Optional
|
8 |
+
|
9 |
+
|
10 |
+
from transformers.generation import BaseStreamer
|
11 |
+
|
12 |
+
|
13 |
+
class AudioStreamer(BaseStreamer):
|
14 |
+
"""
|
15 |
+
Audio streamer that stores audio chunks in queues for each sample in the batch.
|
16 |
+
This allows streaming audio generation for multiple samples simultaneously.
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
batch_size (`int`):
|
20 |
+
The batch size for generation
|
21 |
+
stop_signal (`any`, *optional*):
|
22 |
+
The signal to put in the queue when generation ends. Defaults to None.
|
23 |
+
timeout (`float`, *optional*):
|
24 |
+
The timeout for the audio queue. If `None`, the queue will block indefinitely.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
batch_size: int,
|
30 |
+
stop_signal: Optional[any] = None,
|
31 |
+
timeout: Optional[float] = None,
|
32 |
+
):
|
33 |
+
self.batch_size = batch_size
|
34 |
+
self.stop_signal = stop_signal
|
35 |
+
self.timeout = timeout
|
36 |
+
|
37 |
+
# Create a queue for each sample in the batch
|
38 |
+
self.audio_queues = [Queue() for _ in range(batch_size)]
|
39 |
+
self.finished_flags = [False for _ in range(batch_size)]
|
40 |
+
self.sample_indices_map = {} # Maps from sample index to queue index
|
41 |
+
|
42 |
+
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
|
43 |
+
"""
|
44 |
+
Receives audio chunks and puts them in the appropriate queues.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
|
48 |
+
sample_indices: Tensor indicating which samples these chunks belong to
|
49 |
+
"""
|
50 |
+
for i, sample_idx in enumerate(sample_indices):
|
51 |
+
idx = sample_idx.item()
|
52 |
+
if idx < self.batch_size and not self.finished_flags[idx]:
|
53 |
+
# Detach and move the chunk to CPU (kept as a tensor)
|
54 |
+
audio_chunk = audio_chunks[i].detach().cpu()
|
55 |
+
self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
|
56 |
+
|
57 |
+
def end(self, sample_indices: Optional[torch.Tensor] = None):
|
58 |
+
"""
|
59 |
+
Signals the end of generation for specified samples or all samples.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
sample_indices: Optional tensor of sample indices to end. If None, ends all.
|
63 |
+
"""
|
64 |
+
if sample_indices is None:
|
65 |
+
# End all samples
|
66 |
+
for idx in range(self.batch_size):
|
67 |
+
if not self.finished_flags[idx]:
|
68 |
+
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
|
69 |
+
self.finished_flags[idx] = True
|
70 |
+
else:
|
71 |
+
# End specific samples
|
72 |
+
for sample_idx in sample_indices:
|
73 |
+
idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
|
74 |
+
if idx < self.batch_size and not self.finished_flags[idx]:
|
75 |
+
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
|
76 |
+
self.finished_flags[idx] = True
|
77 |
+
|
78 |
+
def __iter__(self):
|
79 |
+
"""Returns an iterator over the batch of audio streams."""
|
80 |
+
return AudioBatchIterator(self)
|
81 |
+
|
82 |
+
def get_stream(self, sample_idx: int):
|
83 |
+
"""Get the audio stream for a specific sample."""
|
84 |
+
if sample_idx >= self.batch_size:
|
85 |
+
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
|
86 |
+
return AudioSampleIterator(self, sample_idx)
|
87 |
+
|
88 |
+
|
89 |
+
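A hedged sketch of consuming one sample's stream from another thread; the producer below is a stand-in for the real generation loop, which would call put() and end() itself:

import threading
import torch

streamer = AudioStreamer(batch_size=1)

def produce():
    for _ in range(2):
        streamer.put(torch.randn(1, 3200), sample_indices=torch.tensor([0]))
    streamer.end()

threading.Thread(target=produce).start()

for chunk in streamer.get_stream(0):   # blocks until chunks arrive, stops after end()
    print(chunk.shape)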
class AudioSampleIterator:
|
90 |
+
"""Iterator for a single audio stream from the batch."""
|
91 |
+
|
92 |
+
def __init__(self, streamer: AudioStreamer, sample_idx: int):
|
93 |
+
self.streamer = streamer
|
94 |
+
self.sample_idx = sample_idx
|
95 |
+
|
96 |
+
def __iter__(self):
|
97 |
+
return self
|
98 |
+
|
99 |
+
def __next__(self):
|
100 |
+
value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
|
101 |
+
if value == self.streamer.stop_signal:
|
102 |
+
raise StopIteration()
|
103 |
+
return value
|
104 |
+
|
105 |
+
|
106 |
+
class AudioBatchIterator:
|
107 |
+
"""Iterator that yields audio chunks for all samples in the batch."""
|
108 |
+
|
109 |
+
def __init__(self, streamer: AudioStreamer):
|
110 |
+
self.streamer = streamer
|
111 |
+
self.active_samples = set(range(streamer.batch_size))
|
112 |
+
|
113 |
+
def __iter__(self):
|
114 |
+
return self
|
115 |
+
|
116 |
+
def __next__(self):
|
117 |
+
if not self.active_samples:
|
118 |
+
raise StopIteration()
|
119 |
+
|
120 |
+
batch_chunks = {}
|
121 |
+
samples_to_remove = set()
|
122 |
+
|
123 |
+
# Try to get chunks from all active samples
|
124 |
+
for idx in self.active_samples:
|
125 |
+
try:
|
126 |
+
value = self.streamer.audio_queues[idx].get(block=False)
|
127 |
+
if value == self.streamer.stop_signal:
|
128 |
+
samples_to_remove.add(idx)
|
129 |
+
else:
|
130 |
+
batch_chunks[idx] = value
|
131 |
+
except Exception:
|
132 |
+
# Queue is empty for this sample, skip it this iteration
|
133 |
+
pass
|
134 |
+
|
135 |
+
# Remove finished samples
|
136 |
+
self.active_samples -= samples_to_remove
|
137 |
+
|
138 |
+
if batch_chunks:
|
139 |
+
return batch_chunks
|
140 |
+
elif self.active_samples:
|
141 |
+
# If no chunks were ready but we still have active samples,
|
142 |
+
# wait a bit and try again
|
143 |
+
import time
|
144 |
+
time.sleep(0.01)
|
145 |
+
return self.__next__()
|
146 |
+
else:
|
147 |
+
raise StopIteration()
|
148 |
+
|
149 |
+
|
150 |
+
class AsyncAudioStreamer(AudioStreamer):
|
151 |
+
"""
|
152 |
+
Async version of AudioStreamer for use in async contexts.
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
batch_size: int,
|
158 |
+
stop_signal: Optional[any] = None,
|
159 |
+
timeout: Optional[float] = None,
|
160 |
+
):
|
161 |
+
super().__init__(batch_size, stop_signal, timeout)
|
162 |
+
# Replace regular queues with async queues
|
163 |
+
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
|
164 |
+
self.loop = asyncio.get_running_loop()
|
165 |
+
|
166 |
+
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
|
167 |
+
"""Put audio chunks in the appropriate async queues."""
|
168 |
+
for i, sample_idx in enumerate(sample_indices):
|
169 |
+
idx = sample_idx.item()
|
170 |
+
if idx < self.batch_size and not self.finished_flags[idx]:
|
171 |
+
audio_chunk = audio_chunks[i].detach().cpu()
|
172 |
+
self.loop.call_soon_threadsafe(
|
173 |
+
self.audio_queues[idx].put_nowait, audio_chunk
|
174 |
+
)
|
175 |
+
|
176 |
+
def end(self, sample_indices: Optional[torch.Tensor] = None):
|
177 |
+
"""Signal the end of generation for specified samples."""
|
178 |
+
if sample_indices is None:
|
179 |
+
indices_to_end = range(self.batch_size)
|
180 |
+
else:
|
181 |
+
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
182 |
+
|
183 |
+
for idx in indices_to_end:
|
184 |
+
if idx < self.batch_size and not self.finished_flags[idx]:
|
185 |
+
self.loop.call_soon_threadsafe(
|
186 |
+
self.audio_queues[idx].put_nowait, self.stop_signal
|
187 |
+
)
|
188 |
+
self.finished_flags[idx] = True
|
189 |
+
|
190 |
+
async def get_stream(self, sample_idx: int):
|
191 |
+
"""Get async iterator for a specific sample's audio stream."""
|
192 |
+
if sample_idx >= self.batch_size:
|
193 |
+
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
|
194 |
+
|
195 |
+
while True:
|
196 |
+
value = await self.audio_queues[sample_idx].get()
|
197 |
+
if value == self.stop_signal:
|
198 |
+
break
|
199 |
+
yield value
|
200 |
+
|
201 |
+
def __aiter__(self):
|
202 |
+
"""Returns an async iterator over all audio streams."""
|
203 |
+
return AsyncAudioBatchIterator(self)
|
204 |
+
|
205 |
+
|
206 |
+
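The async variant, consumed from a coroutine. AsyncAudioStreamer must be constructed inside a running event loop (it calls asyncio.get_running_loop()); the producer is again a stand-in for the generation loop:

import asyncio
import torch

async def main():
    streamer = AsyncAudioStreamer(batch_size=1)

    def produce():
        streamer.put(torch.randn(1, 3200), sample_indices=torch.tensor([0]))
        streamer.end()

    await asyncio.to_thread(produce)        # run the (normally threaded) producer

    async for chunk in streamer.get_stream(0):
        print(chunk.shape)

asyncio.run(main())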
class AsyncAudioBatchIterator:
|
207 |
+
"""Async iterator for batch audio streaming."""
|
208 |
+
|
209 |
+
def __init__(self, streamer: AsyncAudioStreamer):
|
210 |
+
self.streamer = streamer
|
211 |
+
self.active_samples = set(range(streamer.batch_size))
|
212 |
+
|
213 |
+
def __aiter__(self):
|
214 |
+
return self
|
215 |
+
|
216 |
+
async def __anext__(self):
|
217 |
+
if not self.active_samples:
|
218 |
+
raise StopAsyncIteration()
|
219 |
+
|
220 |
+
batch_chunks = {}
|
221 |
+
samples_to_remove = set()
|
222 |
+
|
223 |
+
# Create tasks for all active samples
|
224 |
+
tasks = {
|
225 |
+
idx: asyncio.create_task(self._get_chunk(idx))
|
226 |
+
for idx in self.active_samples
|
227 |
+
}
|
228 |
+
|
229 |
+
# Wait for at least one chunk to be ready
|
230 |
+
done, pending = await asyncio.wait(
|
231 |
+
tasks.values(),
|
232 |
+
return_when=asyncio.FIRST_COMPLETED,
|
233 |
+
timeout=self.streamer.timeout
|
234 |
+
)
|
235 |
+
|
236 |
+
# Cancel pending tasks
|
237 |
+
for task in pending:
|
238 |
+
task.cancel()
|
239 |
+
|
240 |
+
# Process completed tasks
|
241 |
+
for idx, task in tasks.items():
|
242 |
+
if task in done:
|
243 |
+
try:
|
244 |
+
value = await task
|
245 |
+
if value == self.streamer.stop_signal:
|
246 |
+
samples_to_remove.add(idx)
|
247 |
+
else:
|
248 |
+
batch_chunks[idx] = value
|
249 |
+
except asyncio.CancelledError:
|
250 |
+
pass
|
251 |
+
|
252 |
+
self.active_samples -= samples_to_remove
|
253 |
+
|
254 |
+
if batch_chunks:
|
255 |
+
return batch_chunks
|
256 |
+
elif self.active_samples:
|
257 |
+
# Try again if we still have active samples
|
258 |
+
return await self.__anext__()
|
259 |
+
else:
|
260 |
+
raise StopAsyncIteration()
|
261 |
+
|
262 |
+
async def _get_chunk(self, idx):
|
263 |
+
"""Helper to get a chunk from a specific queue."""
|
264 |
+
return await self.streamer.audio_queues[idx].get()
|
processor/__init__.py
ADDED
File without changes
|
processor/vibevoice_processor.py
ADDED
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from typing import List, Optional, Union, Dict, Any, Tuple
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
11 |
+
from transformers.utils import TensorType, logging
|
12 |
+
from .vibevoice_tokenizer_processor import AudioNormalizer
|
13 |
+
|
14 |
+
logger = logging.get_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class VibeVoiceProcessor:
|
18 |
+
r"""
|
19 |
+
Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
|
20 |
+
|
21 |
+
[`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
|
22 |
+
See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
|
26 |
+
The tokenizer for text processing.
|
27 |
+
audio_processor (`VibeVoiceTokenizerProcessor`):
|
28 |
+
The audio processor for speech processing.
|
29 |
+
speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
|
30 |
+
The compression ratio for speech tokenization.
|
31 |
+
db_normalize (`bool`, *optional*, defaults to True):
|
32 |
+
Whether to apply decibel normalization to audio inputs.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
|
36 |
+
self.tokenizer = tokenizer
|
37 |
+
self.audio_processor = audio_processor
|
38 |
+
self.speech_tok_compress_ratio = speech_tok_compress_ratio
|
39 |
+
self.db_normalize = db_normalize
|
40 |
+
self.audio_normalizer = AudioNormalizer() if db_normalize else None
|
41 |
+
self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
|
42 |
+
|
43 |
+
@classmethod
|
44 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
45 |
+
"""
|
46 |
+
Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
50 |
+
This can be either:
|
51 |
+
- a string, the *model id* of a pretrained model
|
52 |
+
- a path to a *directory* containing processor config
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
[`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
|
56 |
+
"""
|
57 |
+
import os
|
58 |
+
import json
|
59 |
+
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
60 |
+
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
|
61 |
+
VibeVoiceTextTokenizer,
|
62 |
+
VibeVoiceTextTokenizerFast
|
63 |
+
)
|
64 |
+
|
65 |
+
# Load processor configuration
|
66 |
+
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
67 |
+
if os.path.exists(config_path):
|
68 |
+
with open(config_path, 'r') as f:
|
69 |
+
config = json.load(f)
|
70 |
+
else:
|
71 |
+
logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
|
72 |
+
config = {
|
73 |
+
"speech_tok_compress_ratio": 3200,
|
74 |
+
"db_normalize": True,
|
75 |
+
}
|
76 |
+
|
77 |
+
# Extract main processor parameters
|
78 |
+
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
79 |
+
db_normalize = config.get("db_normalize", True)
|
80 |
+
|
81 |
+
# Load tokenizer - try from model path first, then fallback to Qwen
|
82 |
+
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
|
83 |
+
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
|
84 |
+
if 'qwen' in language_model_pretrained_name.lower():
|
85 |
+
tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
|
86 |
+
language_model_pretrained_name,
|
87 |
+
**kwargs
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
|
91 |
+
|
92 |
+
# Load audio processor
|
93 |
+
if "audio_processor" in config:
|
94 |
+
# Create audio processor from config
|
95 |
+
audio_config = config["audio_processor"]
|
96 |
+
audio_processor = VibeVoiceTokenizerProcessor(
|
97 |
+
sampling_rate=audio_config.get("sampling_rate", 24000),
|
98 |
+
normalize_audio=audio_config.get("normalize_audio", True),
|
99 |
+
target_dB_FS=audio_config.get("target_dB_FS", -25),
|
100 |
+
eps=audio_config.get("eps", 1e-6),
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
# Create default audio processor
|
104 |
+
audio_processor = VibeVoiceTokenizerProcessor()
|
105 |
+
|
106 |
+
# Create and return the processor
|
107 |
+
return cls(
|
108 |
+
tokenizer=tokenizer,
|
109 |
+
audio_processor=audio_processor,
|
110 |
+
speech_tok_compress_ratio=speech_tok_compress_ratio,
|
111 |
+
db_normalize=db_normalize,
|
112 |
+
)
|
113 |
+
|
114 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
115 |
+
"""
|
116 |
+
Save a processor to a directory, so that it can be re-loaded using the
|
117 |
+
[`~VibeVoiceProcessor.from_pretrained`] class method.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
save_directory (`str` or `os.PathLike`):
|
121 |
+
Directory where the processor will be saved.
|
122 |
+
"""
|
123 |
+
import os
|
124 |
+
import json
|
125 |
+
|
126 |
+
os.makedirs(save_directory, exist_ok=True)
|
127 |
+
|
128 |
+
# Save processor configuration
|
129 |
+
processor_config = {
|
130 |
+
"processor_class": "VibeVoiceProcessor",
|
131 |
+
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
|
132 |
+
"db_normalize": self.db_normalize,
|
133 |
+
"audio_processor": {
|
134 |
+
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
|
135 |
+
"sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
|
136 |
+
"normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
|
137 |
+
"target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
|
138 |
+
"eps": getattr(self.audio_processor, 'eps', 1e-6),
|
139 |
+
}
|
140 |
+
}
|
141 |
+
|
142 |
+
config_path = os.path.join(save_directory, "preprocessor_config.json")
|
143 |
+
with open(config_path, 'w') as f:
|
144 |
+
json.dump(processor_config, f, indent=2)
|
145 |
+
|
146 |
+
logger.info(f"Processor configuration saved in {config_path}")
|
147 |
+
|
148 |
+
def __call__(
|
149 |
+
self,
|
150 |
+
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
151 |
+
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
|
152 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
153 |
+
truncation: Union[bool, str, TruncationStrategy] = False,
|
154 |
+
max_length: Optional[int] = None,
|
155 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
156 |
+
return_attention_mask: bool = True,
|
157 |
+
**kwargs,
|
158 |
+
) -> BatchEncoding:
|
159 |
+
"""
|
160 |
+
Main method to process one or more podcast scripts with optional voice samples.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
text (`str`, `List[str]`):
|
164 |
+
The input text(s) to process. Can be:
|
165 |
+
- A single script string
|
166 |
+
- A list of script strings for batch processing
|
167 |
+
- A path to a .json or .txt file
|
168 |
+
- A list of paths
|
169 |
+
voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
|
170 |
+
Voice samples for each script. Can be:
|
171 |
+
- A list of samples for a single script
|
172 |
+
- A list of lists for batch processing
|
173 |
+
padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
|
174 |
+
Whether to pad sequences to the same length
|
175 |
+
truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
|
176 |
+
Whether to truncate sequences
|
177 |
+
max_length (`int`, *optional*):
|
178 |
+
Maximum length of the returned sequences
|
179 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
180 |
+
If set, will return tensors of a particular framework
|
181 |
+
return_attention_mask (`bool`, defaults to `True`):
|
182 |
+
Whether to return the attention mask
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
`BatchEncoding`: A BatchEncoding with the following fields:
|
186 |
+
- **input_ids** -- List of token id sequences or tensor
|
187 |
+
- **attention_mask** -- List of attention masks or tensor
|
188 |
+
- **speech_tensors** -- Padded speech inputs (if voice_samples provided)
|
189 |
+
- **speech_masks** -- Speech masks (if voice_samples provided)
|
190 |
+
- **speech_input_mask** -- Boolean masks indicating speech token positions
|
191 |
+
"""
|
192 |
+
# Handle single vs batch input
|
193 |
+
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
|
194 |
+
# Single input
|
195 |
+
texts = [text]
|
196 |
+
is_batched = False
|
197 |
+
else:
|
198 |
+
# Batch input
|
199 |
+
texts = text
|
200 |
+
is_batched = True
|
201 |
+
|
202 |
+
# Handle voice samples
|
203 |
+
if voice_samples is not None:
|
204 |
+
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
|
205 |
+
# Single set of voice samples
|
206 |
+
voice_samples_list = [voice_samples]
|
207 |
+
else:
|
208 |
+
# Batch of voice samples
|
209 |
+
voice_samples_list = voice_samples
|
210 |
+
else:
|
211 |
+
voice_samples_list = [None] * len(texts)
|
212 |
+
|
213 |
+
# Process each input
|
214 |
+
all_encodings = []
|
215 |
+
for text_input, voice_input in zip(texts, voice_samples_list):
|
216 |
+
encoding = self._process_single(text_input, voice_input)
|
217 |
+
all_encodings.append(encoding)
|
218 |
+
|
219 |
+
# Combine batch
|
220 |
+
batch_encoding = self._batch_encode(
|
221 |
+
all_encodings,
|
222 |
+
padding=padding,
|
223 |
+
truncation=truncation,
|
224 |
+
max_length=max_length,
|
225 |
+
return_tensors=return_tensors,
|
226 |
+
return_attention_mask=return_attention_mask,
|
227 |
+
)
|
228 |
+
|
229 |
+
return batch_encoding
|
230 |
+
|
231 |
+
def _process_single(
|
232 |
+
self,
|
233 |
+
text: Union[str, TextInput],
|
234 |
+
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
|
235 |
+
) -> Dict[str, Any]:
|
236 |
+
"""Process a single podcast script."""
|
237 |
+
# Determine if text is a file path or direct script
|
238 |
+
script = None
|
239 |
+
if isinstance(text, str):
|
240 |
+
# Check if it's a file path
|
241 |
+
if text.endswith('.json') and os.path.exists(text):
|
242 |
+
script = self._convert_json_to_script(text)
|
243 |
+
elif text.endswith('.txt') and os.path.exists(text):
|
244 |
+
script = self._convert_text_to_script(text)
|
245 |
+
else:
|
246 |
+
# Assume it's the script content directly
|
247 |
+
script = text
|
248 |
+
|
249 |
+
if script is None:
|
250 |
+
raise ValueError(f"Could not process input text: {text}")
|
251 |
+
|
252 |
+
# Parse the script
|
253 |
+
parsed_lines = self._parse_script(script)
|
254 |
+
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
|
255 |
+
|
256 |
+
# Create system prompt
|
257 |
+
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
|
258 |
+
system_tokens = self.tokenizer.encode(self.system_prompt)
|
259 |
+
|
260 |
+
# Process voice samples if provided
|
261 |
+
if voice_samples:
|
262 |
+
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
|
263 |
+
else:
|
264 |
+
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
|
265 |
+
|
266 |
+
# Build full token sequence
|
267 |
+
full_tokens = system_tokens + voice_tokens
|
268 |
+
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
|
269 |
+
|
270 |
+
# Add text input section
|
271 |
+
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
|
272 |
+
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
|
273 |
+
|
274 |
+
for speaker_id, speaker_text in parsed_lines:
|
275 |
+
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
|
276 |
+
full_tokens += speaker_text_tokens
|
277 |
+
speech_input_mask += [False] * len(speaker_text_tokens)
|
278 |
+
|
279 |
+
# Add speech output section
|
280 |
+
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
|
281 |
+
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
|
282 |
+
|
283 |
+
return {
|
284 |
+
"input_ids": full_tokens,
|
285 |
+
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
|
286 |
+
"speech_input_mask": speech_input_mask,
|
287 |
+
"parsed_script": parsed_lines,
|
288 |
+
"all_speakers": all_speakers,
|
289 |
+
}
|
290 |
+
|
291 |
+
def _batch_encode(
|
292 |
+
self,
|
293 |
+
encodings: List[Dict[str, Any]],
|
294 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
295 |
+
truncation: Union[bool, str, TruncationStrategy] = False,
|
296 |
+
max_length: Optional[int] = None,
|
297 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
298 |
+
return_attention_mask: bool = True,
|
299 |
+
) -> BatchEncoding:
|
300 |
+
"""Combine multiple encodings into a batch with padding."""
|
301 |
+
# Extract input_ids and create attention_mask
|
302 |
+
input_ids_list = [enc["input_ids"] for enc in encodings]
|
303 |
+
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
|
304 |
+
|
305 |
+
# Determine padding strategy
|
306 |
+
if isinstance(padding, bool):
|
307 |
+
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
|
308 |
+
elif isinstance(padding, str):
|
309 |
+
padding_strategy = PaddingStrategy(padding)
|
310 |
+
else:
|
311 |
+
padding_strategy = padding
|
312 |
+
|
313 |
+
# Apply padding to input_ids
|
314 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD:
|
315 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
316 |
+
max_len = max(len(ids) for ids in input_ids_list)
|
317 |
+
elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
|
318 |
+
max_len = max_length
|
319 |
+
else:
|
320 |
+
max_len = max(len(ids) for ids in input_ids_list)
|
321 |
+
|
322 |
+
# Pad sequences
|
323 |
+
padded_input_ids = []
|
324 |
+
attention_masks = []
|
325 |
+
padded_speech_input_masks = []
|
326 |
+
|
327 |
+
for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
|
328 |
+
# Truncate if needed
|
329 |
+
if truncation and len(input_ids) > max_len:
|
330 |
+
input_ids = input_ids[:max_len]
|
331 |
+
speech_mask = speech_mask[:max_len]
|
332 |
+
|
333 |
+
# Pad
|
334 |
+
padding_length = max_len - len(input_ids)
|
335 |
+
# padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
|
336 |
+
padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
|
337 |
+
attention_mask = [0] * padding_length + [1] * len(input_ids)
|
338 |
+
padded_speech_mask = [False] * padding_length + speech_mask
|
339 |
+
|
340 |
+
padded_input_ids.append(padded_ids)
|
341 |
+
attention_masks.append(attention_mask)
|
342 |
+
padded_speech_input_masks.append(padded_speech_mask)
|
343 |
+
|
344 |
+
input_ids_list = padded_input_ids
|
345 |
+
speech_input_masks_list = padded_speech_input_masks
|
346 |
+
else:
|
347 |
+
# No padding, just create attention masks
|
348 |
+
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
|
349 |
+
|
350 |
+
# Process speech inputs
|
351 |
+
all_speech_inputs = []
|
352 |
+
has_speech = False
|
353 |
+
for enc in encodings:
|
354 |
+
if enc["speech_inputs"] is not None:
|
355 |
+
all_speech_inputs.extend(enc["speech_inputs"])
|
356 |
+
has_speech = True
|
357 |
+
|
358 |
+
# Prepare batch encoding
|
359 |
+
batch_encoding = BatchEncoding()
|
360 |
+
|
361 |
+
# Handle tensor conversion
|
362 |
+
if return_tensors is not None:
|
363 |
+
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
|
364 |
+
if return_attention_mask and attention_masks is not None:
|
365 |
+
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
366 |
+
batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
|
367 |
+
else:
|
368 |
+
batch_encoding["input_ids"] = input_ids_list
|
369 |
+
if return_attention_mask and attention_masks is not None:
|
370 |
+
batch_encoding["attention_mask"] = attention_masks
|
371 |
+
batch_encoding["speech_input_mask"] = speech_input_masks_list
|
372 |
+
|
373 |
+
# Process speech tensors if present
|
374 |
+
if has_speech:
|
375 |
+
speech_dict = self.prepare_speech_inputs(
|
376 |
+
all_speech_inputs,
|
377 |
+
return_tensors=return_tensors,
|
378 |
+
)
|
379 |
+
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
|
380 |
+
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
|
381 |
+
else:
|
382 |
+
batch_encoding["speech_tensors"] = None
|
383 |
+
batch_encoding["speech_masks"] = None
|
384 |
+
|
385 |
+
# Add metadata
|
386 |
+
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
|
387 |
+
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
|
388 |
+
|
389 |
+
return batch_encoding
|
390 |
+
|
391 |
+
def _create_voice_prompt(
|
392 |
+
self,
|
393 |
+
speaker_samples: List[Union[str, np.ndarray]]
|
394 |
+
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
|
395 |
+
"""
|
396 |
+
Create voice prompt tokens and process audio samples.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
|
400 |
+
"""
|
401 |
+
vae_token_id = self.tokenizer.speech_diffusion_id
|
402 |
+
|
403 |
+
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
|
404 |
+
voice_speech_inputs = []
|
405 |
+
voice_speech_masks = [False] * len(voice_full_tokens)
|
406 |
+
|
407 |
+
for speaker_id, speaker_audio in enumerate(speaker_samples):
|
408 |
+
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
|
409 |
+
|
410 |
+
# Process audio
|
411 |
+
if isinstance(speaker_audio, str):
|
412 |
+
# Load audio from file
|
413 |
+
wav = self.audio_processor._load_audio_from_path(speaker_audio)
|
414 |
+
else:
|
415 |
+
wav = np.array(speaker_audio, dtype=np.float32)
|
416 |
+
|
417 |
+
# Apply normalization if needed
|
418 |
+
if self.db_normalize and self.audio_normalizer:
|
419 |
+
wav = self.audio_normalizer(wav)
|
420 |
+
|
421 |
+
# Calculate token length based on compression ratio
|
422 |
+
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
|
423 |
+
# vae_tok_len = wav.shape[0]
|
424 |
+
# else:
|
425 |
+
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
|
426 |
+
|
427 |
+
# Build tokens and masks
|
428 |
+
speaker_tokens = (prefix_tokens +
|
429 |
+
[self.tokenizer.speech_start_id] +
|
430 |
+
[vae_token_id] * vae_tok_len +
|
431 |
+
[self.tokenizer.speech_end_id] +
|
432 |
+
self.tokenizer.encode('\n', add_special_tokens=False))
|
433 |
+
|
434 |
+
vae_input_mask = ([False] * len(prefix_tokens) +
|
435 |
+
[False] +
|
436 |
+
[True] * vae_tok_len +
|
437 |
+
[False] +
|
438 |
+
[False])
|
439 |
+
|
440 |
+
voice_full_tokens.extend(speaker_tokens)
|
441 |
+
voice_speech_masks.extend(vae_input_mask)
|
442 |
+
voice_speech_inputs.append(wav)
|
443 |
+
|
444 |
+
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
|
445 |
+
|
446 |
+
def prepare_speech_inputs(
|
447 |
+
self,
|
448 |
+
speech_inputs: List[np.ndarray],
|
449 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
450 |
+
device: Optional[Union[str, torch.device]] = None,
|
451 |
+
dtype: Optional[torch.dtype] = None,
|
452 |
+
) -> Dict[str, Any]:
|
453 |
+
"""
|
454 |
+
Prepare speech inputs for model consumption.
|
455 |
+
|
456 |
+
Args:
|
457 |
+
speech_inputs: List of speech arrays
|
458 |
+
return_tensors: Output tensor type
|
459 |
+
device: Device to place tensors on
|
460 |
+
dtype: Data type for tensors
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
Dictionary with padded_speeches and speech_masks
|
464 |
+
"""
|
465 |
+
if not speech_inputs:
|
466 |
+
return {"padded_speeches": None, "speech_masks": None}
|
467 |
+
|
468 |
+
# Calculate sequence lengths
|
469 |
+
vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
|
470 |
+
# vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
|
471 |
+
max_speech_length = max(s.shape[0] for s in speech_inputs)
|
472 |
+
|
473 |
+
# Pad speeches
|
474 |
+
if speech_inputs[0].ndim == 1:
|
475 |
+
padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
|
476 |
+
else:
|
477 |
+
padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
|
478 |
+
speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
|
479 |
+
|
480 |
+
for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
|
481 |
+
padded_speeches[i, :len(speech)] = speech
|
482 |
+
speech_masks[i, :vae_tok_length] = True
|
483 |
+
|
484 |
+
result = {
|
485 |
+
"padded_speeches": padded_speeches,
|
486 |
+
"speech_masks": speech_masks,
|
487 |
+
}
|
488 |
+
|
489 |
+
# Convert to tensors if requested
|
490 |
+
if return_tensors == "pt":
|
491 |
+
result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
|
492 |
+
result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
|
493 |
+
|
494 |
+
return result
|
495 |
+
|
496 |
+
def _convert_json_to_script(self, json_file: str) -> str:
|
497 |
+
"""
|
498 |
+
Convert JSON format to script format.
|
499 |
+
Expected JSON format:
|
500 |
+
[
|
501 |
+
{"speaker": "1", "text": "Hello everyone..."},
|
502 |
+
{"speaker": "2", "text": "Great to be here..."}
|
503 |
+
]
|
504 |
+
"""
|
505 |
+
import json
|
506 |
+
|
507 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
508 |
+
data = json.load(f)
|
509 |
+
|
510 |
+
if not isinstance(data, list):
|
511 |
+
raise ValueError("JSON file must contain a list of speaker entries")
|
512 |
+
|
513 |
+
script_lines = []
|
514 |
+
for item in data:
|
515 |
+
if not isinstance(item, dict):
|
516 |
+
logger.warning(f"Skipping non-dict entry: {item}")
|
517 |
+
continue
|
518 |
+
|
519 |
+
speaker = item.get('speaker')
|
520 |
+
text = item.get('text')
|
521 |
+
|
522 |
+
if speaker is None or text is None:
|
523 |
+
logger.warning(f"Skipping entry missing speaker or text: {item}")
|
524 |
+
continue
|
525 |
+
|
526 |
+
# Ensure speaker ID is valid
|
527 |
+
try:
|
528 |
+
speaker_id = int(speaker)
|
529 |
+
except (ValueError, TypeError):
|
530 |
+
logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
|
531 |
+
continue
|
532 |
+
|
533 |
+
# Clean up text
|
534 |
+
text = text.strip()
|
535 |
+
if text:
|
536 |
+
script_lines.append(f"Speaker {speaker_id}: {text}")
|
537 |
+
|
538 |
+
if not script_lines:
|
539 |
+
raise ValueError("No valid entries found in JSON file")
|
540 |
+
|
541 |
+
return "\n".join(script_lines)
|
542 |
+
|
543 |
+
def _convert_text_to_script(self, text_file: str) -> str:
|
544 |
+
"""
|
545 |
+
Convert text file to script format.
|
546 |
+
Handles multiple formats:
|
547 |
+
1. Already formatted as "Speaker X: text"
|
548 |
+
2. Plain text (assigns to Speaker 1)
|
549 |
+
|
550 |
+
Handles edge cases like multiple colons in a line.
|
551 |
+
"""
|
552 |
+
with open(text_file, 'r', encoding='utf-8') as f:
|
553 |
+
lines = f.readlines()
|
554 |
+
|
555 |
+
script_lines = []
|
556 |
+
current_speaker = 1
|
557 |
+
|
558 |
+
for line in lines:
|
559 |
+
line = line.strip()
|
560 |
+
if not line:
|
561 |
+
continue
|
562 |
+
|
563 |
+
# Try to parse as "Speaker X: text" format
|
564 |
+
# Use regex to be more robust
|
565 |
+
speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
|
566 |
+
|
567 |
+
if speaker_match:
|
568 |
+
speaker_id = int(speaker_match.group(1))
|
569 |
+
text = speaker_match.group(2).strip()
|
570 |
+
if text:
|
571 |
+
script_lines.append(f"Speaker {speaker_id}: {text}")
|
572 |
+
else:
|
573 |
+
# Treat as plain text - assign to current speaker
|
574 |
+
script_lines.append(f"Speaker {current_speaker}: {line}")
|
575 |
+
|
576 |
+
if not script_lines:
|
577 |
+
raise ValueError("No valid content found in text file")
|
578 |
+
|
579 |
+
return "\n".join(script_lines)
|
580 |
+
|
581 |
+
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
|
582 |
+
"""Parse script into list of (speaker_id, text) tuples."""
|
583 |
+
lines = script.strip().split("\n")
|
584 |
+
parsed_lines = []
|
585 |
+
speaker_ids = []
|
586 |
+
|
587 |
+
# First pass: parse all lines and collect speaker IDs
|
588 |
+
for line in lines:
|
589 |
+
if not line.strip():
|
590 |
+
continue
|
591 |
+
|
592 |
+
# Use regex to handle edge cases like multiple colons
|
593 |
+
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
|
594 |
+
|
595 |
+
if match:
|
596 |
+
speaker_id = int(match.group(1))
|
597 |
+
text = ' ' + match.group(2).strip()
|
598 |
+
parsed_lines.append((speaker_id, text))
|
599 |
+
speaker_ids.append(speaker_id)
|
600 |
+
else:
|
601 |
+
logger.warning(f"Could not parse line: '{line}'")
|
602 |
+
|
603 |
+
if not parsed_lines:
|
604 |
+
raise ValueError("No valid speaker lines found in script")
|
605 |
+
|
606 |
+
# Check if we need to normalize speaker IDs (only if all are > 0)
|
607 |
+
min_speaker_id = min(speaker_ids)
|
608 |
+
if min_speaker_id > 0:
|
609 |
+
# Normalize to start from 0
|
610 |
+
normalized_lines = []
|
611 |
+
for speaker_id, text in parsed_lines:
|
612 |
+
normalized_lines.append((speaker_id - 1, text))
|
613 |
+
return normalized_lines
|
614 |
+
else:
|
615 |
+
# Keep original IDs
|
616 |
+
return parsed_lines
|
617 |
+
|
618 |
+
def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
|
619 |
+
"""Merge text and audio inputs into a single BatchEncoding."""
|
620 |
+
# Start with text inputs
|
621 |
+
merged = BatchEncoding(text_inputs)
|
622 |
+
|
623 |
+
# Add audio-specific fields
|
624 |
+
if "audio" in audio_inputs:
|
625 |
+
merged["speech_inputs"] = audio_inputs["audio"]
|
626 |
+
if "streaming" in audio_inputs:
|
627 |
+
merged["streaming"] = audio_inputs["streaming"]
|
628 |
+
|
629 |
+
return merged
|
630 |
+
|
631 |
+
def batch_decode(self, *args, **kwargs):
|
632 |
+
"""
|
633 |
+
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
|
634 |
+
Please refer to the docstring of this method for more information.
|
635 |
+
"""
|
636 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
637 |
+
|
638 |
+
def decode(self, *args, **kwargs):
|
639 |
+
"""
|
640 |
+
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
|
641 |
+
Please refer to the docstring of this method for more information.
|
642 |
+
"""
|
643 |
+
return self.tokenizer.decode(*args, **kwargs)
|
644 |
+
|
645 |
+
@property
|
646 |
+
def model_input_names(self):
|
647 |
+
"""
|
648 |
+
Return the list of inputs accepted by the model.
|
649 |
+
"""
|
650 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
651 |
+
audio_processor_input_names = self.audio_processor.model_input_names
|
652 |
+
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
|
653 |
+
|
654 |
+
def save_audio(self,
|
655 |
+
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
656 |
+
output_path: str = "output.wav",
|
657 |
+
sampling_rate: Optional[int] = None,
|
658 |
+
normalize: bool = False,
|
659 |
+
batch_prefix: str = "audio_",
|
660 |
+
) -> str:
|
661 |
+
"""
|
662 |
+
Save audio data to a file.
|
663 |
+
Args:
|
664 |
+
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
|
665 |
+
The audio data to save. Can be a single tensor/array or a list of them.
|
666 |
+
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
|
667 |
+
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
|
668 |
+
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
|
669 |
+
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
|
670 |
+
Returns:
|
671 |
+
str: The path to the saved audio file.
|
672 |
+
"""
|
673 |
+
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
|
674 |
+
|
675 |
+
__all__ = [
|
676 |
+
"VibeVoiceProcessor",
|
677 |
+
]
|
processor/vibevoice_tokenizer_processor.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Processor class for VibeVoice models.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import warnings
|
8 |
+
from typing import List, Optional, Union, Dict, Any
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
14 |
+
from transformers.utils import logging
|
15 |
+
|
16 |
+
logger = logging.get_logger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class AudioNormalizer:
|
20 |
+
"""
|
21 |
+
Audio normalization class for VibeVoice tokenizer.
|
22 |
+
|
23 |
+
This class provides audio normalization to ensure consistent input levels
|
24 |
+
for the VibeVoice tokenizer while maintaining audio quality.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
28 |
+
"""
|
29 |
+
Initialize the audio normalizer.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
33 |
+
eps (float): Small value to avoid division by zero. Default: 1e-6
|
34 |
+
"""
|
35 |
+
self.target_dB_FS = target_dB_FS
|
36 |
+
self.eps = eps
|
37 |
+
|
38 |
+
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
39 |
+
"""
|
40 |
+
Adjust the audio to the target dB FS level.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
audio (np.ndarray): Input audio signal
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
tuple: (normalized_audio, rms, scalar)
|
47 |
+
"""
|
48 |
+
rms = np.sqrt(np.mean(audio**2))
|
49 |
+
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
50 |
+
normalized_audio = audio * scalar
|
51 |
+
return normalized_audio, rms, scalar
|
52 |
+
|
53 |
+
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
54 |
+
"""
|
55 |
+
Avoid clipping by scaling down if necessary.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
audio (np.ndarray): Input audio signal
|
59 |
+
scalar (float, optional): Explicit scaling factor
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
tuple: (normalized_audio, scalar)
|
63 |
+
"""
|
64 |
+
if scalar is None:
|
65 |
+
max_val = np.max(np.abs(audio))
|
66 |
+
if max_val > 1.0:
|
67 |
+
scalar = max_val + self.eps
|
68 |
+
else:
|
69 |
+
scalar = 1.0
|
70 |
+
|
71 |
+
return audio / scalar, scalar
|
72 |
+
|
73 |
+
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
74 |
+
"""
|
75 |
+
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
audio (np.ndarray): Input audio signal
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
np.ndarray: Normalized audio signal
|
82 |
+
"""
|
83 |
+
# First adjust to target dB FS
|
84 |
+
audio, _, _ = self.tailor_dB_FS(audio)
|
85 |
+
# Then avoid clipping
|
86 |
+
audio, _ = self.avoid_clipping(audio)
|
87 |
+
return audio
|
88 |
+
|
89 |
+
|
90 |
+
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
91 |
+
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
92 |
+
"""
|
93 |
+
Processor for VibeVoice acoustic tokenizer models.
|
94 |
+
|
95 |
+
This processor handles audio preprocessing for VibeVoice models, including:
|
96 |
+
- Audio format conversion (stereo to mono)
|
97 |
+
- Optional audio normalization
|
98 |
+
- Streaming support for infinite-length audio
|
99 |
+
|
100 |
+
Args:
|
101 |
+
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
|
102 |
+
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
|
103 |
+
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
|
104 |
+
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
|
105 |
+
"""
|
106 |
+
model_input_names = ["input_features"]
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
sampling_rate: int = 24000,
|
111 |
+
normalize_audio: bool = True,
|
112 |
+
target_dB_FS: float = -25,
|
113 |
+
eps: float = 1e-6,
|
114 |
+
**kwargs,
|
115 |
+
):
|
116 |
+
super().__init__(**kwargs)
|
117 |
+
|
118 |
+
self.sampling_rate = sampling_rate
|
119 |
+
self.normalize_audio = normalize_audio
|
120 |
+
|
121 |
+
# Initialize audio normalizer if needed
|
122 |
+
if self.normalize_audio:
|
123 |
+
self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
|
124 |
+
else:
|
125 |
+
self.normalizer = None
|
126 |
+
|
127 |
+
# Save config
|
128 |
+
self.feature_extractor_dict = {
|
129 |
+
"sampling_rate": sampling_rate,
|
130 |
+
"normalize_audio": normalize_audio,
|
131 |
+
"target_dB_FS": target_dB_FS,
|
132 |
+
"eps": eps,
|
133 |
+
}
|
134 |
+
|
135 |
+
def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
|
136 |
+
"""
|
137 |
+
Convert stereo audio to mono if needed.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
audio (np.ndarray): Input audio array
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
np.ndarray: Mono audio array
|
144 |
+
"""
|
145 |
+
if len(audio.shape) == 1:
|
146 |
+
return audio
|
147 |
+
elif len(audio.shape) == 2:
|
148 |
+
if audio.shape[0] == 2: # (2, time)
|
149 |
+
return np.mean(audio, axis=0)
|
150 |
+
elif audio.shape[1] == 2: # (time, 2)
|
151 |
+
return np.mean(audio, axis=1)
|
152 |
+
else:
|
153 |
+
# If one dimension is 1, squeeze it
|
154 |
+
if audio.shape[0] == 1:
|
155 |
+
return audio.squeeze(0)
|
156 |
+
elif audio.shape[1] == 1:
|
157 |
+
return audio.squeeze(1)
|
158 |
+
else:
|
159 |
+
raise ValueError(f"Unexpected audio shape: {audio.shape}")
|
160 |
+
else:
|
161 |
+
raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
|
162 |
+
|
163 |
+
def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
|
164 |
+
"""
|
165 |
+
Process a single audio array.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
audio: Single audio input
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
np.ndarray: Processed audio
|
172 |
+
"""
|
173 |
+
# Convert to numpy array
|
174 |
+
if not isinstance(audio, np.ndarray):
|
175 |
+
audio = np.array(audio, dtype=np.float32)
|
176 |
+
else:
|
177 |
+
audio = audio.astype(np.float32)
|
178 |
+
|
179 |
+
# Ensure mono
|
180 |
+
audio = self._ensure_mono(audio)
|
181 |
+
|
182 |
+
# Normalize if requested
|
183 |
+
if self.normalize_audio and self.normalizer is not None:
|
184 |
+
audio = self.normalizer(audio)
|
185 |
+
|
186 |
+
return audio
|
187 |
+
|
188 |
+
def __call__(
|
189 |
+
self,
|
190 |
+
audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
|
191 |
+
sampling_rate: Optional[int] = None,
|
192 |
+
return_tensors: Optional[str] = None,
|
193 |
+
**kwargs,
|
194 |
+
):
|
195 |
+
"""
|
196 |
+
Process audio for VibeVoice models.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
audio: Audio input(s) to process. Can be:
|
200 |
+
- str: Path to audio file
|
201 |
+
- np.ndarray: Audio array
|
202 |
+
- List[float]: Audio as list of floats
|
203 |
+
- List[np.ndarray]: Batch of audio arrays
|
204 |
+
- List[str]: Batch of audio file paths
|
205 |
+
sampling_rate (int, optional): Sampling rate of the input audio
|
206 |
+
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
dict: Processed audio inputs with keys:
|
210 |
+
- input_features: Audio tensor(s) ready for the model
|
211 |
+
"""
|
212 |
+
if audio is None:
|
213 |
+
raise ValueError("Audio input is required")
|
214 |
+
|
215 |
+
# Validate sampling rate
|
216 |
+
if sampling_rate is not None and sampling_rate != self.sampling_rate:
|
217 |
+
logger.warning(
|
218 |
+
f"Input sampling rate ({sampling_rate}) differs from expected "
|
219 |
+
f"sampling rate ({self.sampling_rate}). Please resample your audio."
|
220 |
+
)
|
221 |
+
|
222 |
+
# Handle different input types
|
223 |
+
if isinstance(audio, str):
|
224 |
+
# Single audio file path
|
225 |
+
audio = self._load_audio_from_path(audio)
|
226 |
+
is_batched = False
|
227 |
+
elif isinstance(audio, list):
|
228 |
+
if len(audio) == 0:
|
229 |
+
raise ValueError("Empty audio list provided")
|
230 |
+
|
231 |
+
# Check if it's a list of file paths
|
232 |
+
if all(isinstance(item, str) for item in audio):
|
233 |
+
# Batch of audio file paths
|
234 |
+
audio = [self._load_audio_from_path(path) for path in audio]
|
235 |
+
is_batched = True
|
236 |
+
else:
|
237 |
+
# Check if it's batched audio arrays
|
238 |
+
is_batched = isinstance(audio[0], (np.ndarray, list))
|
239 |
+
else:
|
240 |
+
# Single audio array or list
|
241 |
+
is_batched = False
|
242 |
+
|
243 |
+
# Process audio
|
244 |
+
if is_batched:
|
245 |
+
processed_audio = [self._process_single_audio(a) for a in audio]
|
246 |
+
else:
|
247 |
+
processed_audio = [self._process_single_audio(audio)]
|
248 |
+
|
249 |
+
# Convert to tensors if requested
|
250 |
+
if return_tensors == "pt":
|
251 |
+
if len(processed_audio) == 1:
|
252 |
+
# Create a proper batch dimension (B, T)
|
253 |
+
input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
|
254 |
+
else:
|
255 |
+
# For batched input with different lengths, create a batch properly
|
256 |
+
input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
|
257 |
+
elif return_tensors == "np":
|
258 |
+
if len(processed_audio) == 1:
|
259 |
+
input_features = processed_audio[0][np.newaxis, np.newaxis, :]
|
260 |
+
else:
|
261 |
+
input_features = np.stack(processed_audio)[:, np.newaxis, :]
|
262 |
+
else:
|
263 |
+
input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
|
264 |
+
|
265 |
+
outputs = {
|
266 |
+
"audio": input_features, # Use "audio" instead of "input_features"
|
267 |
+
}
|
268 |
+
|
269 |
+
return outputs
|
270 |
+
|
271 |
+
def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
|
272 |
+
"""
|
273 |
+
Load audio from file path.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
audio_path (str): Path to audio file
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
np.ndarray: Loaded audio array
|
280 |
+
"""
|
281 |
+
# Get file extension to determine loading method
|
282 |
+
file_ext = os.path.splitext(audio_path)[1].lower()
|
283 |
+
|
284 |
+
if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
|
285 |
+
# Audio file - use librosa
|
286 |
+
import librosa
|
287 |
+
audio_array, sr = librosa.load(
|
288 |
+
audio_path,
|
289 |
+
sr=self.sampling_rate,
|
290 |
+
mono=True
|
291 |
+
)
|
292 |
+
return audio_array
|
293 |
+
elif file_ext == '.pt':
|
294 |
+
# PyTorch tensor file
|
295 |
+
audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
|
296 |
+
if isinstance(audio_tensor, torch.Tensor):
|
297 |
+
audio_array = audio_tensor.numpy()
|
298 |
+
else:
|
299 |
+
audio_array = np.array(audio_tensor)
|
300 |
+
return audio_array.astype(np.float32)
|
301 |
+
elif file_ext == '.npy':
|
302 |
+
# NumPy file
|
303 |
+
audio_array = np.load(audio_path)
|
304 |
+
return audio_array.astype(np.float32)
|
305 |
+
else:
|
306 |
+
raise ValueError(
|
307 |
+
f"Unsupported file format: {file_ext}. "
|
308 |
+
f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
|
309 |
+
)
|
310 |
+
|
311 |
+
def preprocess_audio(
|
312 |
+
self,
|
313 |
+
audio_path_or_array: Union[str, np.ndarray],
|
314 |
+
normalize: Optional[bool] = None,
|
315 |
+
) -> np.ndarray:
|
316 |
+
"""
|
317 |
+
Convenience method to preprocess audio from file path or array.
|
318 |
+
This method is kept for backward compatibility but __call__ is recommended.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
audio_path_or_array: Path to audio file or numpy array
|
322 |
+
normalize: Whether to normalize (overrides default setting)
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
np.ndarray: Preprocessed audio array
|
326 |
+
"""
|
327 |
+
if isinstance(audio_path_or_array, str):
|
328 |
+
audio_array = self._load_audio_from_path(audio_path_or_array)
|
329 |
+
else:
|
330 |
+
audio_array = np.array(audio_path_or_array, dtype=np.float32)
|
331 |
+
|
332 |
+
# Override normalization setting if specified
|
333 |
+
original_normalize = self.normalize_audio
|
334 |
+
if normalize is not None:
|
335 |
+
self.normalize_audio = normalize
|
336 |
+
|
337 |
+
try:
|
338 |
+
processed = self._process_single_audio(audio_array)
|
339 |
+
finally:
|
340 |
+
# Restore original setting
|
341 |
+
self.normalize_audio = original_normalize
|
342 |
+
|
343 |
+
return processed
|
344 |
+
|
345 |
+
# Override to_dict method for configuration saving
|
346 |
+
def to_dict(self) -> Dict[str, Any]:
|
347 |
+
"""
|
348 |
+
Convert the object to a dict containing all attributes needed for serialization.
|
349 |
+
"""
|
350 |
+
return self.feature_extractor_dict
|
351 |
+
|
352 |
+
def save_audio(
|
353 |
+
self,
|
354 |
+
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
355 |
+
output_path: str = "output.wav",
|
356 |
+
sampling_rate: Optional[int] = None,
|
357 |
+
normalize: bool = False,
|
358 |
+
batch_prefix: str = "audio_",
|
359 |
+
):
|
360 |
+
"""
|
361 |
+
Save audio data to WAV file(s).
|
362 |
+
|
363 |
+
Args:
|
364 |
+
audio: Audio data to save. Can be:
|
365 |
+
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
|
366 |
+
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
|
367 |
+
- List of tensors or arrays
|
368 |
+
output_path: Path where to save the audio. If saving multiple files,
|
369 |
+
this is treated as a directory and individual files will be saved inside.
|
370 |
+
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
|
371 |
+
normalize: Whether to normalize audio before saving.
|
372 |
+
batch_prefix: Prefix for batch files when saving multiple audios.
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
List[str]: Paths to the saved audio files.
|
376 |
+
"""
|
377 |
+
if sampling_rate is None:
|
378 |
+
sampling_rate = self.sampling_rate
|
379 |
+
|
380 |
+
try:
|
381 |
+
import soundfile as sf
|
382 |
+
except ImportError:
|
383 |
+
raise ImportError(
|
384 |
+
"soundfile is required to save audio files. "
|
385 |
+
"Install it with: pip install soundfile"
|
386 |
+
)
|
387 |
+
|
388 |
+
# Ensure audio is in the right format
|
389 |
+
if isinstance(audio, torch.Tensor):
|
390 |
+
# Convert PyTorch tensor to numpy
|
391 |
+
audio_np = audio.float().detach().cpu().numpy()
|
392 |
+
elif isinstance(audio, np.ndarray):
|
393 |
+
audio_np = audio
|
394 |
+
elif isinstance(audio, list):
|
395 |
+
# Handle list of tensors or arrays
|
396 |
+
if all(isinstance(a, torch.Tensor) for a in audio):
|
397 |
+
audio_np = [a.float().detach().cpu().numpy() for a in audio]
|
398 |
+
else:
|
399 |
+
audio_np = audio
|
400 |
+
else:
|
401 |
+
raise ValueError(f"Unsupported audio type: {type(audio)}")
|
402 |
+
|
403 |
+
saved_paths = []
|
404 |
+
|
405 |
+
# Handle based on shape or type
|
406 |
+
if isinstance(audio_np, list):
|
407 |
+
# Multiple separate audios to save
|
408 |
+
output_dir = output_path
|
409 |
+
|
410 |
+
# Ensure output directory exists
|
411 |
+
os.makedirs(output_dir, exist_ok=True)
|
412 |
+
|
413 |
+
# Save each audio
|
414 |
+
for i, audio_item in enumerate(audio_np):
|
415 |
+
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
416 |
+
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
417 |
+
sf.write(file_path, audio_item, sampling_rate)
|
418 |
+
saved_paths.append(file_path)
|
419 |
+
|
420 |
+
else:
|
421 |
+
# Handle different dimensions
|
422 |
+
if len(audio_np.shape) >= 3: # (B, C, T) or similar
|
423 |
+
# Get batch size
|
424 |
+
batch_size = audio_np.shape[0]
|
425 |
+
|
426 |
+
if batch_size > 1:
|
427 |
+
# Multiple audios in a batch
|
428 |
+
output_dir = output_path
|
429 |
+
|
430 |
+
# Ensure output directory exists
|
431 |
+
os.makedirs(output_dir, exist_ok=True)
|
432 |
+
|
433 |
+
# Save each audio in the batch
|
434 |
+
for i in range(batch_size):
|
435 |
+
# Extract single audio and remove channel dim if present
|
436 |
+
single_audio = audio_np[i]
|
437 |
+
if len(single_audio.shape) > 1:
|
438 |
+
if single_audio.shape[0] == 1: # (1, T)
|
439 |
+
single_audio = single_audio.squeeze(0)
|
440 |
+
|
441 |
+
single_audio = self._prepare_audio_for_save(single_audio, normalize)
|
442 |
+
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
443 |
+
sf.write(file_path, single_audio, sampling_rate)
|
444 |
+
saved_paths.append(file_path)
|
445 |
+
else:
|
446 |
+
# Single audio with batch and channel dims
|
447 |
+
audio_item = audio_np.squeeze() # Remove batch and channel dimensions
|
448 |
+
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
449 |
+
sf.write(output_path, audio_item, sampling_rate)
|
450 |
+
saved_paths.append(output_path)
|
451 |
+
else:
|
452 |
+
# Single audio without batch dimension
|
453 |
+
audio_item = self._prepare_audio_for_save(audio_np, normalize)
|
454 |
+
sf.write(output_path, audio_item, sampling_rate)
|
455 |
+
saved_paths.append(output_path)
|
456 |
+
|
457 |
+
return saved_paths
|
458 |
+
|
459 |
+
def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
|
460 |
+
"""
|
461 |
+
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
audio: Audio data as numpy array
|
465 |
+
normalize: Whether to normalize audio
|
466 |
+
|
467 |
+
Returns:
|
468 |
+
np.ndarray: Processed audio ready for saving
|
469 |
+
"""
|
470 |
+
# Ensure right dimensionality
|
471 |
+
if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
|
472 |
+
audio = audio.squeeze(0)
|
473 |
+
|
474 |
+
# Normalize if requested
|
475 |
+
if normalize:
|
476 |
+
max_val = np.abs(audio).max()
|
477 |
+
if max_val > 0:
|
478 |
+
audio = audio / max_val
|
479 |
+
|
480 |
+
return audio
|
481 |
+
|
482 |
+
|
483 |
+
__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
|
schedule/__init__.py
ADDED
File without changes
|
schedule/dpm_solver.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import deprecate
|
25 |
+
from diffusers.utils.torch_utils import randn_tensor
|
26 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
27 |
+
|
28 |
+
def betas_for_alpha_bar(
|
29 |
+
num_diffusion_timesteps,
|
30 |
+
max_beta=0.999,
|
31 |
+
alpha_transform_type="cosine",
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
35 |
+
(1-beta) over time from t = [0,1].
|
36 |
+
|
37 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
38 |
+
to that part of the diffusion process.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
43 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
44 |
+
prevent singularities.
|
45 |
+
alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
|
46 |
+
Choose from `cosine`, `exp`, `cauchy`, or `laplace`
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
50 |
+
"""
|
51 |
+
if alpha_transform_type == "cosine":
|
52 |
+
|
53 |
+
def alpha_bar_fn(t):
|
54 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
55 |
+
# return math.cos(t * math.pi / 2 * 0.95) ** 2
|
56 |
+
|
57 |
+
elif alpha_transform_type == "exp":
|
58 |
+
|
59 |
+
def alpha_bar_fn(t):
|
60 |
+
return math.exp(t * -12.0)
|
61 |
+
|
62 |
+
elif alpha_transform_type == "cauchy":
|
63 |
+
# µ + γ tan (π (0.5 - x)) γ = 1, µ = 3
|
64 |
+
# alpha^2 = 1-1/(exp(λ)+1)
|
65 |
+
def alpha_bar_fn(t, gamma=1, mu=3):
|
66 |
+
snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
|
67 |
+
return 1 - 1 / (math.exp(snr) + 1.1)
|
68 |
+
|
69 |
+
elif alpha_transform_type == "laplace":
|
70 |
+
# µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1
|
71 |
+
def alpha_bar_fn(t, mu=0, b=1):
|
72 |
+
snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
|
73 |
+
return 1 - 1 / (math.exp(snr) + 1.02)
|
74 |
+
|
75 |
+
else:
|
76 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
77 |
+
|
78 |
+
betas = []
|
79 |
+
for i in range(num_diffusion_timesteps):
|
80 |
+
t1 = i / num_diffusion_timesteps
|
81 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
82 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
83 |
+
return torch.tensor(betas, dtype=torch.float32)
|
84 |
+
|
85 |
+
|
86 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
87 |
+
def rescale_zero_terminal_snr(betas):
|
88 |
+
"""
|
89 |
+
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
|
90 |
+
|
91 |
+
|
92 |
+
Args:
|
93 |
+
betas (`torch.Tensor`):
|
94 |
+
the betas that the scheduler is being initialized with.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
98 |
+
"""
|
99 |
+
# Convert betas to alphas_bar_sqrt
|
100 |
+
alphas = 1.0 - betas
|
101 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
102 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
103 |
+
|
104 |
+
# Store old values.
|
105 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
106 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
107 |
+
|
108 |
+
# Shift so the last timestep is zero.
|
109 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
110 |
+
|
111 |
+
# Scale so the first timestep is back to the old value.
|
112 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
113 |
+
|
114 |
+
# Convert alphas_bar_sqrt to betas
|
115 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
116 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
117 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
118 |
+
betas = 1 - alphas
|
119 |
+
|
120 |
+
return betas
|
121 |
+
|
122 |
+
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
123 |
+
"""
|
124 |
+
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
125 |
+
|
126 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
127 |
+
methods the library implements for all schedulers such as loading and saving.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
num_train_timesteps (`int`, defaults to 1000):
|
131 |
+
The number of diffusion steps to train the model.
|
132 |
+
beta_start (`float`, defaults to 0.0001):
|
133 |
+
The starting `beta` value of inference.
|
134 |
+
beta_end (`float`, defaults to 0.02):
|
135 |
+
The final `beta` value.
|
136 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
137 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
138 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
139 |
+
trained_betas (`np.ndarray`, *optional*):
|
140 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
141 |
+
solver_order (`int`, defaults to 2):
|
142 |
+
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
143 |
+
sampling, and `solver_order=3` for unconditional sampling.
|
144 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
145 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
146 |
+
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
147 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
148 |
+
thresholding (`bool`, defaults to `False`):
|
149 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
150 |
+
as Stable Diffusion.
|
151 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
152 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
153 |
+
sample_max_value (`float`, defaults to 1.0):
|
154 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
155 |
+
`algorithm_type="dpmsolver++"`.
|
156 |
+
algorithm_type (`str`, defaults to `dpmsolver++`):
|
157 |
+
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
158 |
+
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
159 |
+
paper, and the `dpmsolver++` type implements the algorithms in the
|
160 |
+
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
161 |
+
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
162 |
+
solver_type (`str`, defaults to `midpoint`):
|
163 |
+
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
164 |
+
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
165 |
+
lower_order_final (`bool`, defaults to `True`):
|
166 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
167 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
168 |
+
euler_at_final (`bool`, defaults to `False`):
|
169 |
+
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
170 |
+
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
171 |
+
steps, but sometimes may result in blurring.
|
172 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
173 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
174 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
175 |
+
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
|
176 |
+
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
177 |
+
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
178 |
+
`lambda(t)`.
|
179 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
180 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
181 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
182 |
+
lambda_min_clipped (`float`, defaults to `-inf`):
|
183 |
+
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
184 |
+
cosine (`squaredcos_cap_v2`) noise schedule.
|
185 |
+
variance_type (`str`, *optional*):
|
186 |
+
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
187 |
+
contains the predicted Gaussian variance.
|
188 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
189 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
190 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
191 |
+
steps_offset (`int`, defaults to 0):
|
192 |
+
An offset added to the inference steps, as required by some model families.
|
193 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
194 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
195 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
196 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
197 |
+
"""
|
198 |
+
|
199 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
200 |
+
order = 1
|
201 |
+
|
202 |
+
@register_to_config
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
num_train_timesteps: int = 1000,
|
206 |
+
beta_start: float = 0.0001,
|
207 |
+
beta_end: float = 0.02,
|
208 |
+
beta_schedule: str = "linear",
|
209 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
210 |
+
solver_order: int = 2,
|
211 |
+
prediction_type: str = "epsilon",
|
212 |
+
thresholding: bool = False,
|
213 |
+
dynamic_thresholding_ratio: float = 0.995,
|
214 |
+
sample_max_value: float = 1.0,
|
215 |
+
algorithm_type: str = "dpmsolver++",
|
216 |
+
solver_type: str = "midpoint",
|
217 |
+
lower_order_final: bool = True,
|
218 |
+
euler_at_final: bool = False,
|
219 |
+
use_karras_sigmas: Optional[bool] = False,
|
220 |
+
use_lu_lambdas: Optional[bool] = False,
|
221 |
+
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
222 |
+
lambda_min_clipped: float = -float("inf"),
|
223 |
+
variance_type: Optional[str] = None,
|
224 |
+
timestep_spacing: str = "linspace",
|
225 |
+
steps_offset: int = 0,
|
226 |
+
rescale_betas_zero_snr: bool = False,
|
227 |
+
):
|
228 |
+
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
229 |
+
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
230 |
+
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
231 |
+
|
232 |
+
if trained_betas is not None:
|
233 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
234 |
+
elif beta_schedule == "linear":
|
235 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
236 |
+
elif beta_schedule == "scaled_linear":
|
237 |
+
# this schedule is very specific to the latent diffusion model.
|
238 |
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
239 |
+
elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
|
240 |
+
# Glide cosine schedule
|
241 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
|
242 |
+
elif beta_schedule == "cauchy":
|
243 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
|
244 |
+
elif beta_schedule == "laplace":
|
245 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
|
246 |
+
else:
|
247 |
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
248 |
+
|
249 |
+
if rescale_betas_zero_snr:
|
250 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
251 |
+
|
252 |
+
self.alphas = 1.0 - self.betas
|
253 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
254 |
+
|
255 |
+
if rescale_betas_zero_snr:
|
256 |
+
# Close to 0 without being 0 so first sigma is not inf
|
257 |
+
# FP16 smallest positive subnormal works well here
|
258 |
+
self.alphas_cumprod[-1] = 2**-24
|
259 |
+
|
260 |
+
# Currently we only support VP-type noise schedule
|
261 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
262 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
263 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
264 |
+
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
265 |
+
|
266 |
+
# standard deviation of the initial noise distribution
|
267 |
+
self.init_noise_sigma = 1.0
|
268 |
+
|
269 |
+
# settings for DPM-Solver
|
270 |
+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
271 |
+
if algorithm_type == "deis":
|
272 |
+
self.register_to_config(algorithm_type="dpmsolver++")
|
273 |
+
else:
|
274 |
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
275 |
+
|
276 |
+
if solver_type not in ["midpoint", "heun"]:
|
277 |
+
if solver_type in ["logrho", "bh1", "bh2"]:
|
278 |
+
self.register_to_config(solver_type="midpoint")
|
279 |
+
else:
|
280 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
281 |
+
|
282 |
+
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
283 |
+
raise ValueError(
|
284 |
+
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
|
285 |
+
)
|
286 |
+
|
287 |
+
# setable values
|
288 |
+
self.num_inference_steps = None
|
289 |
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
290 |
+
self.timesteps = torch.from_numpy(timesteps)
|
291 |
+
self.model_outputs = [None] * solver_order
|
292 |
+
self.lower_order_nums = 0
|
293 |
+
self._step_index = None
|
294 |
+
self._begin_index = None
|
295 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
296 |
+
|
297 |
+
@property
|
298 |
+
def step_index(self):
|
299 |
+
"""
|
300 |
+
The index counter for the current timestep. It increases by 1 after each scheduler step.
|
301 |
+
"""
|
302 |
+
return self._step_index
|
303 |
+
|
304 |
+
@property
|
305 |
+
def begin_index(self):
|
306 |
+
"""
|
307 |
+
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
|
308 |
+
"""
|
309 |
+
return self._begin_index
|
310 |
+
|
311 |
+
def set_begin_index(self, begin_index: int = 0):
|
312 |
+
"""
|
313 |
+
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
begin_index (`int`):
|
317 |
+
The begin index for the scheduler.
|
318 |
+
"""
|
319 |
+
self._begin_index = begin_index
|
320 |
+
|
321 |
+
def set_timesteps(
|
322 |
+
self,
|
323 |
+
num_inference_steps: int = None,
|
324 |
+
device: Union[str, torch.device] = None,
|
325 |
+
timesteps: Optional[List[int]] = None,
|
326 |
+
):
|
327 |
+
"""
|
328 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
329 |
+
|
330 |
+
Args:
|
331 |
+
num_inference_steps (`int`):
|
332 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
333 |
+
device (`str` or `torch.device`, *optional*):
|
334 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
335 |
+
timesteps (`List[int]`, *optional*):
|
336 |
+
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
337 |
+
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
338 |
+
must be `None`, and `timestep_spacing` attribute will be ignored.
|
339 |
+
"""
|
340 |
+
if num_inference_steps is None and timesteps is None:
|
341 |
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
342 |
+
if num_inference_steps is not None and timesteps is not None:
|
343 |
+
raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
|
344 |
+
if timesteps is not None and self.config.use_karras_sigmas:
|
345 |
+
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
|
346 |
+
if timesteps is not None and self.config.use_lu_lambdas:
|
347 |
+
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
|
348 |
+
|
349 |
+
if timesteps is not None:
|
350 |
+
timesteps = np.array(timesteps).astype(np.int64)
|
351 |
+
else:
|
352 |
+
# Clipping the minimum of all lambda(t) for numerical stability.
|
353 |
+
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
354 |
+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
|
355 |
+
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
|
356 |
+
|
357 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
358 |
+
if self.config.timestep_spacing == "linspace":
|
359 |
+
timesteps = (
|
360 |
+
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
|
361 |
+
.round()[::-1][:-1]
|
362 |
+
.copy()
|
363 |
+
.astype(np.int64)
|
364 |
+
)
|
365 |
+
elif self.config.timestep_spacing == "leading":
|
366 |
+
step_ratio = last_timestep // (num_inference_steps + 1)
|
367 |
+
# creates integer timesteps by multiplying by ratio
|
368 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
369 |
+
timesteps = (
|
370 |
+
(np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
|
371 |
+
)
|
372 |
+
timesteps += self.config.steps_offset
|
373 |
+
elif self.config.timestep_spacing == "trailing":
|
374 |
+
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
375 |
+
# creates integer timesteps by multiplying by ratio
|
376 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
377 |
+
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
|
378 |
+
timesteps -= 1
|
379 |
+
else:
|
380 |
+
raise ValueError(
|
381 |
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
382 |
+
)
|
383 |
+
|
384 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
385 |
+
log_sigmas = np.log(sigmas)
|
386 |
+
|
387 |
+
if self.config.use_karras_sigmas:
|
388 |
+
sigmas = np.flip(sigmas).copy()
|
389 |
+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
390 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
391 |
+
elif self.config.use_lu_lambdas:
|
392 |
+
lambdas = np.flip(log_sigmas.copy())
|
393 |
+
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
|
394 |
+
sigmas = np.exp(lambdas)
|
395 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
396 |
+
else:
|
397 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
398 |
+
|
399 |
+
if self.config.final_sigmas_type == "sigma_min":
|
400 |
+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
401 |
+
elif self.config.final_sigmas_type == "zero":
|
402 |
+
sigma_last = 0
|
403 |
+
else:
|
404 |
+
raise ValueError(
|
405 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
406 |
+
)
|
407 |
+
|
408 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
409 |
+
|
410 |
+
self.sigmas = torch.from_numpy(sigmas)
|
411 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
412 |
+
|
413 |
+
self.num_inference_steps = len(timesteps)
|
414 |
+
|
415 |
+
self.model_outputs = [
|
416 |
+
None,
|
417 |
+
] * self.config.solver_order
|
418 |
+
self.lower_order_nums = 0
|
419 |
+
|
420 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
421 |
+
self._step_index = None
|
422 |
+
self._begin_index = None
|
423 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
424 |
+
|
425 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
426 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
427 |
+
"""
|
428 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
429 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
430 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
431 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
432 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
433 |
+
|
434 |
+
https://arxiv.org/abs/2205.11487
|
435 |
+
"""
|
436 |
+
dtype = sample.dtype
|
437 |
+
batch_size, channels, *remaining_dims = sample.shape
|
438 |
+
|
439 |
+
if dtype not in (torch.float32, torch.float64):
|
440 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
441 |
+
|
442 |
+
# Flatten sample for doing quantile calculation along each image
|
443 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
444 |
+
|
445 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
446 |
+
|
447 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
448 |
+
s = torch.clamp(
|
449 |
+
s, min=1, max=self.config.sample_max_value
|
450 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
451 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
452 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
453 |
+
|
454 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
455 |
+
sample = sample.to(dtype)
|
456 |
+
|
457 |
+
return sample
|
458 |
+
|
459 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
460 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
461 |
+
# get log sigma
|
462 |
+
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
463 |
+
|
464 |
+
# get distribution
|
465 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
466 |
+
|
467 |
+
# get sigmas range
|
468 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
469 |
+
high_idx = low_idx + 1
|
470 |
+
|
471 |
+
low = log_sigmas[low_idx]
|
472 |
+
high = log_sigmas[high_idx]
|
473 |
+
|
474 |
+
# interpolate sigmas
|
475 |
+
w = (low - log_sigma) / (low - high)
|
476 |
+
w = np.clip(w, 0, 1)
|
477 |
+
|
478 |
+
# transform interpolation to time range
|
479 |
+
t = (1 - w) * low_idx + w * high_idx
|
480 |
+
t = t.reshape(sigma.shape)
|
481 |
+
return t
|
482 |
+
|
483 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
484 |
+
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
485 |
+
sigma_t = sigma * alpha_t
|
486 |
+
|
487 |
+
return alpha_t, sigma_t
|
488 |
+
|
489 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
490 |
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
491 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
492 |
+
|
493 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
494 |
+
# TODO: Add this logic to the other schedulers
|
495 |
+
if hasattr(self.config, "sigma_min"):
|
496 |
+
sigma_min = self.config.sigma_min
|
497 |
+
else:
|
498 |
+
sigma_min = None
|
499 |
+
|
500 |
+
if hasattr(self.config, "sigma_max"):
|
501 |
+
sigma_max = self.config.sigma_max
|
502 |
+
else:
|
503 |
+
sigma_max = None
|
504 |
+
|
505 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
506 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
507 |
+
|
508 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
509 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
510 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
511 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
512 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
513 |
+
return sigmas
|
514 |
+
|
515 |
+
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
516 |
+
"""Constructs the noise schedule of Lu et al. (2022)."""
|
517 |
+
|
518 |
+
lambda_min: float = in_lambdas[-1].item()
|
519 |
+
lambda_max: float = in_lambdas[0].item()
|
520 |
+
|
521 |
+
rho = 1.0 # 1.0 is the value used in the paper
|
522 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
523 |
+
min_inv_rho = lambda_min ** (1 / rho)
|
524 |
+
max_inv_rho = lambda_max ** (1 / rho)
|
525 |
+
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
526 |
+
return lambdas
|
527 |
+
|
528 |
+
def convert_model_output(
|
529 |
+
self,
|
530 |
+
model_output: torch.Tensor,
|
531 |
+
*args,
|
532 |
+
sample: torch.Tensor = None,
|
533 |
+
**kwargs,
|
534 |
+
) -> torch.Tensor:
|
535 |
+
"""
|
536 |
+
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
537 |
+
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
538 |
+
integral of the data prediction model.
|
539 |
+
|
540 |
+
<Tip>
|
541 |
+
|
542 |
+
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
|
543 |
+
prediction and data prediction models.
|
544 |
+
|
545 |
+
</Tip>
|
546 |
+
|
547 |
+
Args:
|
548 |
+
model_output (`torch.Tensor`):
|
549 |
+
The direct output from the learned diffusion model.
|
550 |
+
sample (`torch.Tensor`):
|
551 |
+
A current instance of a sample created by the diffusion process.
|
552 |
+
|
553 |
+
Returns:
|
554 |
+
`torch.Tensor`:
|
555 |
+
The converted model output.
|
556 |
+
"""
|
557 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
558 |
+
if sample is None:
|
559 |
+
if len(args) > 1:
|
560 |
+
sample = args[1]
|
561 |
+
else:
|
562 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
563 |
+
if timestep is not None:
|
564 |
+
deprecate(
|
565 |
+
"timesteps",
|
566 |
+
"1.0.0",
|
567 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
568 |
+
)
|
569 |
+
|
570 |
+
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
571 |
+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
572 |
+
if self.config.prediction_type == "epsilon":
|
573 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
574 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
575 |
+
model_output = model_output[:, :3]
|
576 |
+
sigma = self.sigmas[self.step_index]
|
577 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
578 |
+
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
579 |
+
elif self.config.prediction_type == "sample":
|
580 |
+
x0_pred = model_output
|
581 |
+
elif self.config.prediction_type == "v_prediction":
|
582 |
+
sigma = self.sigmas[self.step_index]
|
583 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
584 |
+
x0_pred = alpha_t * sample - sigma_t * model_output
|
585 |
+
else:
|
586 |
+
raise ValueError(
|
587 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
588 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
589 |
+
)
|
590 |
+
|
591 |
+
if self.config.thresholding:
|
592 |
+
x0_pred = self._threshold_sample(x0_pred)
|
593 |
+
|
594 |
+
return x0_pred
|
595 |
+
|
596 |
+
# DPM-Solver needs to solve an integral of the noise prediction model.
|
597 |
+
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
598 |
+
if self.config.prediction_type == "epsilon":
|
599 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
600 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
601 |
+
epsilon = model_output[:, :3]
|
602 |
+
else:
|
603 |
+
epsilon = model_output
|
604 |
+
elif self.config.prediction_type == "sample":
|
605 |
+
sigma = self.sigmas[self.step_index]
|
606 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
607 |
+
epsilon = (sample - alpha_t * model_output) / sigma_t
|
608 |
+
elif self.config.prediction_type == "v_prediction":
|
609 |
+
sigma = self.sigmas[self.step_index]
|
610 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
611 |
+
epsilon = alpha_t * model_output + sigma_t * sample
|
612 |
+
else:
|
613 |
+
raise ValueError(
|
614 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
615 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
616 |
+
)
|
617 |
+
|
618 |
+
if self.config.thresholding:
|
619 |
+
sigma = self.sigmas[self.step_index]
|
620 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
621 |
+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
622 |
+
x0_pred = self._threshold_sample(x0_pred)
|
623 |
+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
624 |
+
|
625 |
+
return epsilon
|
626 |
+
|
627 |
+
def dpm_solver_first_order_update(
|
628 |
+
self,
|
629 |
+
model_output: torch.Tensor,
|
630 |
+
*args,
|
631 |
+
sample: torch.Tensor = None,
|
632 |
+
noise: Optional[torch.Tensor] = None,
|
633 |
+
**kwargs,
|
634 |
+
) -> torch.Tensor:
|
635 |
+
"""
|
636 |
+
One step for the first-order DPMSolver (equivalent to DDIM).
|
637 |
+
|
638 |
+
Args:
|
639 |
+
model_output (`torch.Tensor`):
|
640 |
+
The direct output from the learned diffusion model.
|
641 |
+
sample (`torch.Tensor`):
|
642 |
+
A current instance of a sample created by the diffusion process.
|
643 |
+
|
644 |
+
Returns:
|
645 |
+
`torch.Tensor`:
|
646 |
+
The sample tensor at the previous timestep.
|
647 |
+
"""
|
648 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
649 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
650 |
+
if sample is None:
|
651 |
+
if len(args) > 2:
|
652 |
+
sample = args[2]
|
653 |
+
else:
|
654 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
655 |
+
if timestep is not None:
|
656 |
+
deprecate(
|
657 |
+
"timesteps",
|
658 |
+
"1.0.0",
|
659 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
660 |
+
)
|
661 |
+
|
662 |
+
if prev_timestep is not None:
|
663 |
+
deprecate(
|
664 |
+
"prev_timestep",
|
665 |
+
"1.0.0",
|
666 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
667 |
+
)
|
668 |
+
|
669 |
+
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
670 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
671 |
+
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
672 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
673 |
+
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
674 |
+
|
675 |
+
h = lambda_t - lambda_s
|
676 |
+
if self.config.algorithm_type == "dpmsolver++":
|
677 |
+
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
678 |
+
elif self.config.algorithm_type == "dpmsolver":
|
679 |
+
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
680 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
681 |
+
assert noise is not None
|
682 |
+
x_t = (
|
683 |
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
684 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
685 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
686 |
+
)
|
687 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
688 |
+
assert noise is not None
|
689 |
+
x_t = (
|
690 |
+
(alpha_t / alpha_s) * sample
|
691 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
692 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
693 |
+
)
|
694 |
+
return x_t
|
695 |
+
|
696 |
+
def multistep_dpm_solver_second_order_update(
|
697 |
+
self,
|
698 |
+
model_output_list: List[torch.Tensor],
|
699 |
+
*args,
|
700 |
+
sample: torch.Tensor = None,
|
701 |
+
noise: Optional[torch.Tensor] = None,
|
702 |
+
**kwargs,
|
703 |
+
) -> torch.Tensor:
|
704 |
+
"""
|
705 |
+
One step for the second-order multistep DPMSolver.
|
706 |
+
|
707 |
+
Args:
|
708 |
+
model_output_list (`List[torch.Tensor]`):
|
709 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
710 |
+
sample (`torch.Tensor`):
|
711 |
+
A current instance of a sample created by the diffusion process.
|
712 |
+
|
713 |
+
Returns:
|
714 |
+
`torch.Tensor`:
|
715 |
+
The sample tensor at the previous timestep.
|
716 |
+
"""
|
717 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
718 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
719 |
+
if sample is None:
|
720 |
+
if len(args) > 2:
|
721 |
+
sample = args[2]
|
722 |
+
else:
|
723 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
724 |
+
if timestep_list is not None:
|
725 |
+
deprecate(
|
726 |
+
"timestep_list",
|
727 |
+
"1.0.0",
|
728 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
729 |
+
)
|
730 |
+
|
731 |
+
if prev_timestep is not None:
|
732 |
+
deprecate(
|
733 |
+
"prev_timestep",
|
734 |
+
"1.0.0",
|
735 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
736 |
+
)
|
737 |
+
|
738 |
+
sigma_t, sigma_s0, sigma_s1 = (
|
739 |
+
self.sigmas[self.step_index + 1],
|
740 |
+
self.sigmas[self.step_index],
|
741 |
+
self.sigmas[self.step_index - 1],
|
742 |
+
)
|
743 |
+
|
744 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
745 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
746 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
747 |
+
|
748 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
749 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
750 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
751 |
+
|
752 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
753 |
+
|
754 |
+
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
755 |
+
r0 = h_0 / h
|
756 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
757 |
+
if self.config.algorithm_type == "dpmsolver++":
|
758 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
759 |
+
if self.config.solver_type == "midpoint":
|
760 |
+
x_t = (
|
761 |
+
(sigma_t / sigma_s0) * sample
|
762 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
763 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
764 |
+
)
|
765 |
+
elif self.config.solver_type == "heun":
|
766 |
+
x_t = (
|
767 |
+
(sigma_t / sigma_s0) * sample
|
768 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
769 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
770 |
+
)
|
771 |
+
elif self.config.algorithm_type == "dpmsolver":
|
772 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
773 |
+
if self.config.solver_type == "midpoint":
|
774 |
+
x_t = (
|
775 |
+
(alpha_t / alpha_s0) * sample
|
776 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
777 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
778 |
+
)
|
779 |
+
elif self.config.solver_type == "heun":
|
780 |
+
x_t = (
|
781 |
+
(alpha_t / alpha_s0) * sample
|
782 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
783 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
784 |
+
)
|
785 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
786 |
+
assert noise is not None
|
787 |
+
if self.config.solver_type == "midpoint":
|
788 |
+
x_t = (
|
789 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
790 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
791 |
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
792 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
793 |
+
)
|
794 |
+
elif self.config.solver_type == "heun":
|
795 |
+
x_t = (
|
796 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
797 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
798 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
799 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
800 |
+
)
|
801 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
802 |
+
assert noise is not None
|
803 |
+
if self.config.solver_type == "midpoint":
|
804 |
+
x_t = (
|
805 |
+
(alpha_t / alpha_s0) * sample
|
806 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
807 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
808 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
809 |
+
)
|
810 |
+
elif self.config.solver_type == "heun":
|
811 |
+
x_t = (
|
812 |
+
(alpha_t / alpha_s0) * sample
|
813 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
814 |
+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
815 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
816 |
+
)
|
817 |
+
return x_t
|
818 |
+
|
819 |
+
def multistep_dpm_solver_third_order_update(
|
820 |
+
self,
|
821 |
+
model_output_list: List[torch.Tensor],
|
822 |
+
*args,
|
823 |
+
sample: torch.Tensor = None,
|
824 |
+
**kwargs,
|
825 |
+
) -> torch.Tensor:
|
826 |
+
"""
|
827 |
+
One step for the third-order multistep DPMSolver.
|
828 |
+
|
829 |
+
Args:
|
830 |
+
model_output_list (`List[torch.Tensor]`):
|
831 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
832 |
+
sample (`torch.Tensor`):
|
833 |
+
A current instance of a sample created by the diffusion process.
|
834 |
+
|
835 |
+
Returns:
|
836 |
+
`torch.Tensor`:
|
837 |
+
The sample tensor at the previous timestep.
|
838 |
+
"""
|
839 |
+
|
840 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
841 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
842 |
+
if sample is None:
|
843 |
+
if len(args) > 2:
|
844 |
+
sample = args[2]
|
845 |
+
else:
|
846 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
847 |
+
if timestep_list is not None:
|
848 |
+
deprecate(
|
849 |
+
"timestep_list",
|
850 |
+
"1.0.0",
|
851 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
852 |
+
)
|
853 |
+
|
854 |
+
if prev_timestep is not None:
|
855 |
+
deprecate(
|
856 |
+
"prev_timestep",
|
857 |
+
"1.0.0",
|
858 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
859 |
+
)
|
860 |
+
|
861 |
+
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
862 |
+
self.sigmas[self.step_index + 1],
|
863 |
+
self.sigmas[self.step_index],
|
864 |
+
self.sigmas[self.step_index - 1],
|
865 |
+
self.sigmas[self.step_index - 2],
|
866 |
+
)
|
867 |
+
|
868 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
869 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
870 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
871 |
+
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
872 |
+
|
873 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
874 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
875 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
876 |
+
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
877 |
+
|
878 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
879 |
+
|
880 |
+
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
881 |
+
r0, r1 = h_0 / h, h_1 / h
|
882 |
+
D0 = m0
|
883 |
+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
884 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
885 |
+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
886 |
+
if self.config.algorithm_type == "dpmsolver++":
|
887 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
888 |
+
x_t = (
|
889 |
+
(sigma_t / sigma_s0) * sample
|
890 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
891 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
892 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
893 |
+
)
|
894 |
+
elif self.config.algorithm_type == "dpmsolver":
|
895 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
896 |
+
x_t = (
|
897 |
+
(alpha_t / alpha_s0) * sample
|
898 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
899 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
900 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
901 |
+
)
|
902 |
+
return x_t
|
903 |
+
|
904 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
905 |
+
if schedule_timesteps is None:
|
906 |
+
schedule_timesteps = self.timesteps
|
907 |
+
|
908 |
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
909 |
+
|
910 |
+
if len(index_candidates) == 0:
|
911 |
+
step_index = len(self.timesteps) - 1
|
912 |
+
# The sigma index that is taken for the **very** first `step`
|
913 |
+
# is always the second index (or the last index if there is only 1)
|
914 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
915 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
916 |
+
elif len(index_candidates) > 1:
|
917 |
+
step_index = index_candidates[1].item()
|
918 |
+
else:
|
919 |
+
step_index = index_candidates[0].item()
|
920 |
+
|
921 |
+
return step_index
|
922 |
+
|
923 |
+
def _init_step_index(self, timestep):
|
924 |
+
"""
|
925 |
+
Initialize the step_index counter for the scheduler.
|
926 |
+
"""
|
927 |
+
|
928 |
+
if self.begin_index is None:
|
929 |
+
if isinstance(timestep, torch.Tensor):
|
930 |
+
timestep = timestep.to(self.timesteps.device)
|
931 |
+
self._step_index = self.index_for_timestep(timestep)
|
932 |
+
else:
|
933 |
+
self._step_index = self._begin_index
|
934 |
+
|
935 |
+
def step(
|
936 |
+
self,
|
937 |
+
model_output: torch.Tensor,
|
938 |
+
timestep: int,
|
939 |
+
sample: torch.Tensor,
|
940 |
+
generator=None,
|
941 |
+
variance_noise: Optional[torch.Tensor] = None,
|
942 |
+
return_dict: bool = True,
|
943 |
+
) -> Union[SchedulerOutput, Tuple]:
|
944 |
+
"""
|
945 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
946 |
+
the multistep DPMSolver.
|
947 |
+
|
948 |
+
Args:
|
949 |
+
model_output (`torch.Tensor`):
|
950 |
+
The direct output from the learned diffusion model.
|
951 |
+
timestep (`int`):
|
952 |
+
The current discrete timestep in the diffusion chain.
|
953 |
+
sample (`torch.Tensor`):
|
954 |
+
A current instance of a sample created by the diffusion process.
|
955 |
+
generator (`torch.Generator`, *optional*):
|
956 |
+
A random number generator.
|
957 |
+
variance_noise (`torch.Tensor`):
|
958 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
959 |
+
itself. Useful for methods such as [`LEdits++`].
|
960 |
+
return_dict (`bool`):
|
961 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
962 |
+
|
963 |
+
Returns:
|
964 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
965 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
966 |
+
tuple is returned where the first element is the sample tensor.
|
967 |
+
|
968 |
+
"""
|
969 |
+
if self.num_inference_steps is None:
|
970 |
+
raise ValueError(
|
971 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
972 |
+
)
|
973 |
+
|
974 |
+
if self.step_index is None:
|
975 |
+
self._init_step_index(timestep)
|
976 |
+
|
977 |
+
# Improve numerical stability for small number of steps
|
978 |
+
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
|
979 |
+
self.config.euler_at_final
|
980 |
+
or (self.config.lower_order_final and len(self.timesteps) < 15)
|
981 |
+
or self.config.final_sigmas_type == "zero"
|
982 |
+
)
|
983 |
+
lower_order_second = (
|
984 |
+
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
985 |
+
)
|
986 |
+
|
987 |
+
model_output = self.convert_model_output(model_output, sample=sample)
|
988 |
+
for i in range(self.config.solver_order - 1):
|
989 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
990 |
+
self.model_outputs[-1] = model_output
|
991 |
+
|
992 |
+
# Upcast to avoid precision issues when computing prev_sample
|
993 |
+
sample = sample.to(torch.float32)
|
994 |
+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
995 |
+
noise = randn_tensor(
|
996 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
997 |
+
)
|
998 |
+
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
999 |
+
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
1000 |
+
else:
|
1001 |
+
noise = None
|
1002 |
+
|
1003 |
+
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
1004 |
+
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
1005 |
+
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
1006 |
+
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
1007 |
+
else:
|
1008 |
+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
1009 |
+
|
1010 |
+
if self.lower_order_nums < self.config.solver_order:
|
1011 |
+
self.lower_order_nums += 1
|
1012 |
+
|
1013 |
+
# Cast sample back to expected dtype
|
1014 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
1015 |
+
|
1016 |
+
# upon completion increase step index by one
|
1017 |
+
self._step_index += 1
|
1018 |
+
|
1019 |
+
if not return_dict:
|
1020 |
+
return (prev_sample,)
|
1021 |
+
|
1022 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
1023 |
+
|
1024 |
+
def add_noise(
|
1025 |
+
self,
|
1026 |
+
original_samples: torch.Tensor,
|
1027 |
+
noise: torch.Tensor,
|
1028 |
+
timesteps: torch.IntTensor,
|
1029 |
+
) -> torch.Tensor:
|
1030 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
1031 |
+
# alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
1032 |
+
# sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
1033 |
+
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
|
1034 |
+
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
|
1035 |
+
timesteps = timesteps.to(original_samples.device)
|
1036 |
+
alpha_t = alpha_t[timesteps].flatten()
|
1037 |
+
while len(alpha_t.shape) < len(original_samples.shape):
|
1038 |
+
alpha_t = alpha_t.unsqueeze(-1)
|
1039 |
+
|
1040 |
+
sigma_t = sigma_t[timesteps].flatten()
|
1041 |
+
while len(sigma_t.shape) < len(original_samples.shape):
|
1042 |
+
sigma_t = sigma_t.unsqueeze(-1)
|
1043 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
1044 |
+
return noisy_samples
|
1045 |
+
|
1046 |
+
def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
|
1047 |
+
# alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
1048 |
+
# sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
1049 |
+
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
|
1050 |
+
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
|
1051 |
+
|
1052 |
+
timesteps = timesteps.to(original_samples.device)
|
1053 |
+
alpha_t = alpha_t[timesteps].flatten()
|
1054 |
+
while len(alpha_t.shape) < len(original_samples.shape):
|
1055 |
+
alpha_t = alpha_t.unsqueeze(-1)
|
1056 |
+
|
1057 |
+
sigma_t = sigma_t[timesteps].flatten()
|
1058 |
+
while len(sigma_t.shape) < len(original_samples.shape):
|
1059 |
+
sigma_t = sigma_t.unsqueeze(-1)
|
1060 |
+
|
1061 |
+
velocity = alpha_t * noise - sigma_t * original_samples
|
1062 |
+
return velocity
|
1063 |
+
|
1064 |
+
def __len__(self):
|
1065 |
+
return self.config.num_train_timesteps
|
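A minimal usage sketch of the `DPMSolverMultistepScheduler` defined above. It assumes the file is importable as `schedule.dpm_solver` and substitutes a `dummy_model` placeholder for a real noise-prediction network; the latent shape, step count, and `beta_schedule` choice here are illustrative only, not taken from any particular configuration.

import torch

from schedule.dpm_solver import DPMSolverMultistepScheduler

# Build the scheduler and select 20 inference steps out of 1000 training steps.
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="cosine",
    algorithm_type="dpmsolver++",
    solver_order=2,
)
scheduler.set_timesteps(num_inference_steps=20)

def dummy_model(x, t):
    # Placeholder for a learned epsilon-prediction network.
    return torch.zeros_like(x)

# Start from Gaussian noise and run the multistep solver to the final sample.
sample = torch.randn(1, 64)
for t in scheduler.timesteps:
    model_output = dummy_model(sample, t)
    sample = scheduler.step(model_output, t, sample).prev_sample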
schedule/timestep_sampler.py
ADDED
@@ -0,0 +1,19 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class UniformSampler:
|
6 |
+
def __init__(self, timesteps = 1000):
|
7 |
+
self.timesteps = timesteps
|
8 |
+
def sample(self, batch_size, device):
|
9 |
+
return torch.randint(0, self.timesteps, (batch_size,), device=device)
|
10 |
+
|
11 |
+
class LogitNormalSampler:
|
12 |
+
def __init__(self, timesteps = 1000, m = 0, s = 1):
|
13 |
+
self.timesteps = timesteps
|
14 |
+
timesteps = torch.linspace(0, 1, timesteps)
|
15 |
+
logit = torch.log(timesteps / (1 - timesteps))
|
16 |
+
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
|
17 |
+
def sample(self, batch_size, device):
|
18 |
+
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
|
19 |
+
|
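A short illustrative sketch of the two timestep samplers above, assuming the file is importable as `schedule.timestep_sampler`; the batch size and device are arbitrary. `UniformSampler` draws indices uniformly over `[0, timesteps)`, while `LogitNormalSampler` concentrates its draws around the middle of the schedule.

import torch

from schedule.timestep_sampler import UniformSampler, LogitNormalSampler

device = "cuda" if torch.cuda.is_available() else "cpu"

uniform = UniformSampler(timesteps=1000)
logit_normal = LogitNormalSampler(timesteps=1000, m=0, s=1)

t_uniform = uniform.sample(batch_size=8, device=device)
t_logit_normal = logit_normal.sample(batch_size=8, device=device)

print(t_uniform.shape, t_logit_normal.shape)  # torch.Size([8]) torch.Size([8])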
scripts/__init__.py
ADDED
File without changes
|
scripts/convert_nnscaler_checkpoint_to_transformers.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from typing import Dict, List, Tuple
|
11 |
+
|
12 |
+
from vibevoice.modular.configuration_vibevoice import (
|
13 |
+
VibeVoiceConfig
|
14 |
+
)
|
15 |
+
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
|
16 |
+
from transformers.utils import logging
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__)
|
19 |
+
|
20 |
+
def convert_vibevoice_nnscaler_checkpoint_to_hf(
|
21 |
+
checkpoint_path: str,
|
22 |
+
pytorch_dump_folder_path: str,
|
23 |
+
config_path: str = None,
|
24 |
+
):
|
25 |
+
"""
|
26 |
+
Convert an nnscaler VibeVoice checkpoint to HuggingFace format.
|
27 |
+
Supports both regular checkpoints and tensor parallel checkpoints.
|
28 |
+
"""
|
29 |
+
|
30 |
+
# Load regular checkpoint
|
31 |
+
logger.info(f"Loading regular checkpoint from {checkpoint_path}")
|
32 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
|
33 |
+
|
34 |
+
# config = checkpoint['train_args']
|
35 |
+
init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
|
36 |
+
pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
|
37 |
+
|
38 |
+
init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
|
39 |
+
if init_config_path.exists():
|
40 |
+
logger.info(f"Loading initial config from {init_config_path}")
|
41 |
+
with open(init_config_path, 'r') as f:
|
42 |
+
init_config = json.load(f)
|
43 |
+
else:
|
44 |
+
raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
|
45 |
+
|
46 |
+
tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
|
47 |
+
logger.info(f"Tie word embeddings: {tie_word_embeddings}")
|
48 |
+
|
49 |
+
init_config['decoder_config']['use_cache'] = True
|
50 |
+
config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
|
51 |
+
|
52 |
+
# Extract the model state dict
|
53 |
+
model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
|
54 |
+
if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
|
55 |
+
# If not tying weights, we need to add the lm_head weight separately
|
56 |
+
model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
|
57 |
+
|
58 |
+
# Override with provided config if available
|
59 |
+
if config_path:
|
60 |
+
logger.info(f"Loading config from {config_path}")
|
61 |
+
with open(config_path, 'r') as f:
|
62 |
+
config_dict = json.load(f)
|
63 |
+
config = VibeVoiceConfig.from_dict(config_dict)
|
64 |
+
|
65 |
+
# Set the default dtype to bfloat16 before creating the model
|
66 |
+
original_dtype = torch.get_default_dtype()
|
67 |
+
torch.set_default_dtype(torch.bfloat16)
|
68 |
+
|
69 |
+
# Create the HuggingFace model
|
70 |
+
logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
|
71 |
+
model = VibeVoiceForConditionalGeneration(config)
|
72 |
+
|
73 |
+
# Restore original dtype
|
74 |
+
torch.set_default_dtype(original_dtype)
|
75 |
+
|
76 |
+
# Load the state dict
|
77 |
+
logger.info("Loading weights into model")
|
78 |
+
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
|
79 |
+
|
80 |
+
if missing_keys:
|
81 |
+
logger.warning(f"Missing keys: {missing_keys}")
|
82 |
+
if unexpected_keys:
|
83 |
+
logger.warning(f"Unexpected keys: {unexpected_keys}")
|
84 |
+
|
85 |
+
# Create output directory
|
86 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
87 |
+
|
88 |
+
# Save the model and config
|
89 |
+
logger.info(f"Saving model to {pytorch_dump_folder_path}")
|
90 |
+
|
91 |
+
# Save config
|
92 |
+
config.save_pretrained(pytorch_dump_folder_path)
|
93 |
+
|
94 |
+
# Save VibeVoiceProcessor configuration
|
95 |
+
logger.info("Saving VibeVoiceProcessor configuration")
|
96 |
+
processor_config = {
|
97 |
+
"processor_class": "VibeVoiceProcessor",
|
98 |
+
"speech_tok_compress_ratio": 3200,
|
99 |
+
"db_normalize": True,
|
100 |
+
# Audio processor configuration
|
101 |
+
"audio_processor": {
|
102 |
+
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
|
103 |
+
"sampling_rate": 24000,
|
104 |
+
"normalize_audio": True,
|
105 |
+
"target_dB_FS": -25,
|
106 |
+
"eps": 1e-6,
|
107 |
+
},
|
108 |
+
"language_model_pretrained_name": pretrained_name,
|
109 |
+
}
|
110 |
+
|
111 |
+
processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
|
112 |
+
with open(processor_config_path, 'w') as f:
|
113 |
+
json.dump(processor_config, f, indent=2)
|
114 |
+
logger.info(f"Saved processor config to {processor_config_path}")
|
115 |
+
|
116 |
+
# Save model with sharding
|
117 |
+
# save_pretrained handles tied weights automatically
|
118 |
+
logger.info("Saving model weights with sharding...")
|
119 |
+
model.save_pretrained(
|
120 |
+
pytorch_dump_folder_path,
|
121 |
+
max_shard_size="2GB", # Set maximum size for each shard
|
122 |
+
safe_serialization=True # Ensure saving in .safetensors format
|
123 |
+
)
|
124 |
+
logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
|
125 |
+
|
126 |
+
logger.info("Conversion complete!")
|
127 |
+
|
128 |
+
# Verify the saved model can be loaded
|
129 |
+
logger.info("Verifying saved model...")
|
130 |
+
loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
|
131 |
+
logger.info("Model successfully loaded from saved checkpoint!")
|
132 |
+
|
133 |
+
def main():
|
134 |
+
parser = argparse.ArgumentParser()
|
135 |
+
parser.add_argument(
|
136 |
+
"--nnscaler_checkpoint_path",
|
137 |
+
type=str,
|
138 |
+
required=True,
|
139 |
+
help="Path to the nnscaler checkpoint (.pt file). For tensor parallel checkpoints, "
|
140 |
+
"provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
|
141 |
+
"and the script will automatically detect and merge all parts.",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--pytorch_dump_folder_path",
|
145 |
+
type=str,
|
146 |
+
required=True,
|
147 |
+
help="Path to the output PyTorch model directory",
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--config_path",
|
151 |
+
type=str,
|
152 |
+
default=None,
|
153 |
+
help="Optional path to a config JSON file to override extracted config",
|
154 |
+
)
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
convert_vibevoice_nnscaler_checkpoint_to_hf(
|
159 |
+
args.nnscaler_checkpoint_path,
|
160 |
+
args.pytorch_dump_folder_path,
|
161 |
+
args.config_path,
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
if __name__ == "__main__":
|
166 |
+
main()
|