Mark Shi committed on
Commit
c0a944c
·
1 Parent(s): 789bd04

upload code

app.py ADDED
@@ -0,0 +1,253 @@
1
+ import spaces
2
+
3
+ import subprocess
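+ # Flash-attn is installed at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA kernels from source (a prebuilt wheel is expected).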
4
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
5
+
6
+ import os
+ import tempfile
7
+ import random
8
+ import shutil
9
+ import pickle
10
+ import gradio as gr
11
+ import soundfile as sf
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ import torchaudio
16
+
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ from infer import load_model, eval_model
20
+ from spkr import SpeakerEmbedding
21
+
22
+
23
+ spkr_model = SpeakerEmbedding(device="cuda")
24
+ model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer")
25
+ default_ref_file = "examples/character_ref_emb_demo.pkl"
26
+ default_ref_name = "Homer Simpson"
27
+ million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset")
28
+
29
+ instruction = "You are a smart AI agent created by Maitrix.org."
30
+ save_path = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
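+ # Per-session outputs are written under this directory and cleaned up by delete_directory when the session ends.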
31
+
32
+ intro = """**Voila**
33
+
34
+ For more demos, please go to [https://voila.maitrix.org](https://voila.maitrix.org)."""
35
+
36
+ default_ref_emb_mask_list = pickle.load(open(default_ref_file, "rb"))
37
+ million_voice_ref_emb_mask_list = pickle.load(open(million_voice_ref_file, "rb"))
38
+
39
+ def get_ref_embs(ref_audio):
40
+ wav, sr = torchaudio.load(ref_audio)
41
+ ref_embs = spkr_model(wav, sr).cpu()
42
+ return ref_embs
43
+
44
+ def delete_directory(request: gr.Request):
45
+ if not request.session_hash:
46
+ return
47
+ user_dir = Path(f"{save_path}/{str(request.session_hash)}")
48
+ if user_dir.exists():
49
+ shutil.rmtree(str(user_dir))
50
+
51
+ def add_message(history, message):
52
+ history.append({"role": "user", "content": {"path": message}})
53
+ return history, gr.Audio(value=None), gr.Button(interactive=False)
54
+
55
+ def call_bot(history, ref_embs, request: gr.Request):
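+ # Convert the Gradio message history into Voila's conversation format, append an empty assistant turn, and generate a reply conditioned on the current speaker embedding.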
56
+ formated_history = {
57
+ "instruction": instruction,
58
+ "conversations": [{'from': item["role"], 'audio': {"file": item["content"][0]}} for item in history],
59
+ }
60
+ formated_history["conversations"].append({"from": "assistant"})
61
+ print(formated_history)
62
+ ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
63
+ ref_embs_mask = torch.tensor([1], device="cuda")
64
+ out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
65
+ if 'audio' in out:
66
+ wav, sr = out['audio']
67
+
68
+ user_dir = Path(f"{save_path}/{str(request.session_hash)}")
69
+ user_dir.mkdir(exist_ok=True)
70
+ save_name = f"{user_dir}/{len(history)}.wav"
71
+ sf.write(save_name, wav, sr)
72
+
73
+ history.append({"role": "assistant", "content": {"path": save_name}})
74
+ else:
75
+ history.append({"role": "assistant", "content": {"text": out['text']}})
76
+
77
+ return history
78
+
79
+ def run_tts(text, ref_embs):
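+ # Single-turn TTS: the user turn carries only text; the model synthesizes audio in the voice given by ref_embs.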
80
+ formated_history = {
81
+ "instruction": "",
82
+ "conversations": [{'from': "user", 'text': text}],
83
+ }
84
+ formated_history["conversations"].append({"from": "assistant"})
85
+ ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
86
+ ref_embs_mask = torch.tensor([1], device="cuda")
87
+ out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
88
+ if 'audio' in out:
89
+ wav, sr = out['audio']
90
+ return (sr, wav)
91
+ else:
92
+ raise Exception("No audio output")
93
+
94
+ def run_asr(audio):
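+ # Single-turn ASR: transcription only, so no speaker embedding is needed (ref_embs/ref_embs_mask are None).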
95
+ formated_history = {
96
+ "instruction": "",
97
+ "conversations": [{'from': "user", 'audio': {"file": audio}}],
98
+ }
99
+ formated_history["conversations"].append({"from": "assistant"})
100
+ out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formated_history, None, None, max_new_tokens=512)
101
+ if 'text' in out:
102
+ return out['text']
103
+ else:
104
+ raise Exception("No text output")
105
+
106
+
107
+ def markdown_ref_name(ref_name):
108
+ return f"### Current voice id: {ref_name}"
109
+
110
+ def random_million_voice():
111
+ voice_id = random.choice(list(million_voice_ref_emb_mask_list.keys()))
112
+ return markdown_ref_name(voice_id), million_voice_ref_emb_mask_list[voice_id]
113
+
114
+ def get_ref_modules(cur_ref_embs):
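+ # Builds the reference-voice picker: a preset dropdown, a "Random from Million Voice" button, and an optional custom-audio recorder that re-computes the speaker embedding.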
115
+ with gr.Row() as ref_row:
116
+ with gr.Row():
117
+ current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name))
118
+ with gr.Row() as ref_name_row:
119
+ with gr.Column(scale=2, min_width=160):
120
+ ref_name_dropdown = gr.Dropdown(
121
+ choices=list(default_ref_emb_mask_list.keys()),
122
+ value=default_ref_name,
123
+ label="Reference voice",
124
+ min_width=160,
125
+ )
126
+ with gr.Column(scale=1, min_width=80):
127
+ random_ref_button = gr.Button(
128
+ "Random from Million Voice", size="md",
129
+ )
130
+ with gr.Row(visible=False) as ref_audio_row:
131
+ with gr.Column(scale=2, min_width=80):
132
+ ref_audio = gr.Audio(
133
+ sources=["microphone", "upload"],
134
+ type="filepath",
135
+ show_label=False,
136
+ min_width=80,
137
+ )
138
+ with gr.Column(scale=1, min_width=80):
139
+ change_ref_button = gr.Button(
140
+ "Change voice",
141
+ interactive=False,
142
+ min_width=80,
143
+ )
144
+ ref_name_dropdown.change(
145
+ lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]),
146
+ ref_name_dropdown,
147
+ [current_ref_name, cur_ref_embs]
148
+ )
149
+ random_ref_button.click(
150
+ random_million_voice,
151
+ None,
152
+ [current_ref_name, cur_ref_embs],
153
+ )
154
+ ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button)
155
+ # If custom ref voice checkbox is checked, show the Audio component to record or upload a reference voice
156
+ custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False)
157
+ # Checked: enable audio and button
158
+ # Unchecked: disable audio and button
159
+ def custom_ref_voice_change(x, cur_ref_embs):
160
+ if not x:
161
+ cur_ref_embs = default_ref_emb_mask_list[default_ref_name]
162
+ return [gr.Row(visible=not x), gr.Audio(value=None), gr.Row(visible=x), markdown_ref_name("Custom voice"), cur_ref_embs]
163
+ custom_ref_voice.change(
164
+ custom_ref_voice_change,
165
+ [custom_ref_voice, cur_ref_embs],
166
+ [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs]
167
+ )
168
+ # When change ref button is clicked, get the reference voice and update the reference voice state
169
+ change_ref_button.click(
170
+ lambda: gr.Button(interactive=False), None, [change_ref_button]
171
+ ).then(
172
+ get_ref_embs, ref_audio, cur_ref_embs
173
+ )
174
+ return ref_row
175
+
176
+ def get_chat_tab():
177
+ cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
178
+ with gr.Row() as chat_tab:
179
+ with gr.Column(scale=1):
180
+ ref_row = get_ref_modules(cur_ref_embs)
181
+ # Voice chat input
182
+ chat_input = gr.Audio(
183
+ sources=["microphone", "upload"],
184
+ type="filepath",
185
+ show_label=False,
186
+ )
187
+ submit = gr.Button("Submit", interactive=False)
188
+ gr.Markdown(intro)
189
+ with gr.Column(scale=9):
190
+ chatbot = gr.Chatbot(
191
+ elem_id="chatbot",
192
+ type="messages",
193
+ bubble_full_width=False,
194
+ scale=1,
195
+ show_copy_button=False,
196
+ avatar_images=(
197
+ None, # os.path.join("files", "avatar.png"),
198
+ None, # os.path.join("files", "avatar.png"),
199
+ ),
200
+ )
201
+
202
+ chat_input.input(lambda: gr.Button(interactive=True), None, submit)
203
+ submit.click(
204
+ add_message, [chatbot, chat_input], [chatbot, chat_input, submit]
205
+ ).then(
206
+ call_bot, [chatbot, cur_ref_embs], chatbot, api_name="bot_response"
207
+ )
208
+ return chat_tab
209
+
210
+ def get_tts_tab():
211
+ cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
212
+ with gr.Row() as tts_tab:
213
+ with gr.Column(scale=1):
214
+ ref_row = get_ref_modules(cur_ref_embs)
215
+ gr.Markdown(intro)
216
+ with gr.Column(scale=9):
217
+ tts_output = gr.Audio(label="TTS output", interactive=False)
218
+ with gr.Row():
219
+ text_input = gr.Textbox(label="Text", placeholder="Text to TTS")
220
+ submit = gr.Button("Submit")
221
+ submit.click(
222
+ run_tts, [text_input, cur_ref_embs], tts_output
223
+ )
224
+ return tts_tab
225
+
226
+ def get_asr_tab():
227
+ with gr.Row() as asr_tab:
228
+ with gr.Column():
229
+ asr_input = gr.Audio(
230
+ label="ASR input",
231
+ sources=["microphone", "upload"],
232
+ type="filepath",
233
+ )
234
+ submit = gr.Button("Submit")
235
+ gr.Markdown(intro)
236
+ with gr.Column():
237
+ asr_output = gr.Textbox(label="ASR output", interactive=False)
238
+ submit.click(
239
+ run_asr, [asr_input], asr_output
240
+ )
241
+ return asr_tab
242
+
243
+ with gr.Blocks(fill_height=True) as demo:
244
+ with gr.Tab("Chat"):
245
+ chat_tab = get_chat_tab()
246
+ with gr.Tab("TTS"):
247
+ tts_tab = get_tts_tab()
248
+ with gr.Tab("ASR"):
249
+ asr_tab = get_asr_tab()
250
+ demo.unload(delete_directory)
251
+
252
+ if __name__ == "__main__":
253
+ demo.launch()
audio_transformer.py ADDED
@@ -0,0 +1,354 @@
1
+ import math
2
+ from typing import Optional
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ from torch.nn import functional as F
+ # Needed for the SDPA flash-attention context manager used below (torch.nn.attention, PyTorch 2.3+)
+ from torch.nn.attention import SDPBackend, sdpa_kernel
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ @dataclass
14
+ class LocalArgs:
15
+ codebook_size: int = 2048
16
+ num_codebooks: int = 4
17
+
18
+ # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L105
19
+ class KVCache(nn.Module):
20
+ def __init__(
21
+ self, n_layer, batch_size, max_seq_len, n_heads, head_dim, dtype, device
22
+ ):
23
+ super().__init__()
24
+ cache_shape = (n_layer, batch_size, n_heads, max_seq_len, head_dim)
25
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
26
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
27
+
28
+ def update(self, layer_idx, input_pos, k_val, v_val):
29
+ # k_val: [B, H, S, D]
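+ # Write the new key/value at position input_pos for this layer and return the full cached tensors.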
30
+
31
+ k_out = self.k_cache
32
+ v_out = self.v_cache
33
+ k_out[layer_idx, :, :, input_pos:input_pos+1] = k_val
34
+ v_out[layer_idx, :, :, input_pos:input_pos+1] = v_val
35
+
36
+ return k_out[layer_idx], v_out[layer_idx]
37
+
38
+ # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L756
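+ # Precompute rotary position embedding angles; returns a [seq_len, n_elem // 2, 2] tensor of (cos, sin) pairs.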
39
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
40
+ freqs = 1.0 / (
41
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
42
+ )
43
+ t = torch.arange(seq_len, device=freqs.device)
44
+ freqs = torch.outer(t, freqs)
45
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
46
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
47
+ return cache
48
+
49
+ # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L767
50
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
51
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
52
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
53
+ x_out2 = torch.stack(
54
+ [
55
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
56
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
57
+ ],
58
+ -1,
59
+ )
60
+
61
+ x_out2 = x_out2.flatten(3)
62
+ return x_out2.type_as(x)
63
+
64
+ # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L742
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, dim: int, eps: float = 1e-5):
67
+ super().__init__()
68
+ self.eps = eps
69
+ self.weight = nn.Parameter(torch.ones(dim))
70
+
71
+ def _norm(self, x):
72
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
73
+
74
+ def forward(self, x: Tensor) -> Tensor:
75
+ output = self._norm(x.float()).type_as(x)
76
+ return output * self.weight
77
+
78
+ # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L731
79
+ class FeedForward(nn.Module):
80
+ def __init__(self, config: LocalArgs) -> None:
81
+ super().__init__()
82
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
83
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
84
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
85
+
86
+ def forward(self, x: Tensor) -> Tensor:
87
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
88
+
89
+ # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L615
90
+ class Attention(nn.Module):
91
+ def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True):
92
+ super().__init__()
93
+ assert config.dim % config.n_head == 0
94
+ self.layer_idx = layer_idx
95
+
96
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
97
+ # key, query, value projections for all heads, but in a batch
98
+ self.wqkv = nn.Linear(
99
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
100
+ )
101
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
102
+
103
+ self.dropout = config.dropout
104
+ self.n_head = config.n_head
105
+ self.head_dim = config.head_dim
106
+ self.n_local_heads = config.n_local_heads
107
+ self.dim = config.dim
108
+ self.use_sdpa = use_sdpa
109
+ self._register_load_state_dict_pre_hook(self.load_hook)
110
+
111
+ def load_hook(self, state_dict, prefix, *args):
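+ # Backward compatibility: fuse separate wq/wk/wv weights from older checkpoints into the single wqkv projection.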
112
+ if prefix + "wq.weight" in state_dict:
113
+ wq = state_dict.pop(prefix + "wq.weight")
114
+ wk = state_dict.pop(prefix + "wk.weight")
115
+ wv = state_dict.pop(prefix + "wv.weight")
116
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
117
+
118
+ def forward(
119
+ self,
120
+ x: Tensor,
121
+ freqs_cis: Tensor,
122
+ mask: Tensor,
123
+ input_pos: Optional[int] = None,
124
+ kv_cache: Optional[KVCache] = None,
125
+ ) -> Tensor:
126
+ bsz, seqlen, _ = x.shape
127
+
128
+ kv_size = self.n_local_heads * self.head_dim
129
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
130
+
131
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
132
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
133
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
134
+
135
+ q = apply_rotary_emb(q, freqs_cis)
136
+ k = apply_rotary_emb(k, freqs_cis)
137
+
138
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
139
+
140
+ if kv_cache is not None:
141
+ k, v = kv_cache.update(self.layer_idx, input_pos, k, v)
142
+
143
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
144
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
145
+
146
+ if self.use_sdpa:
147
+ if mask is None:
148
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
149
+ y = F.scaled_dot_product_attention(
150
+ q,
151
+ k,
152
+ v,
153
+ dropout_p=self.dropout if self.training else 0.0,
154
+ is_causal=True,
155
+ # No attn_mask is passed here so that the flash-attention backend can be used
156
+ )
157
+ else:
158
+ y = F.scaled_dot_product_attention(
159
+ q,
160
+ k,
161
+ v,
162
+ attn_mask=mask,
163
+ dropout_p=self.dropout if self.training else 0.0,
164
+ )
165
+ else:
166
+ y = self.eq_scaled_dot_product_attention(
167
+ q,
168
+ k,
169
+ v,
170
+ attn_mask=mask,
171
+ dropout_p=self.dropout if self.training else 0.0,
172
+ )
173
+
174
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
175
+
176
+ return self.wo(y)
177
+
178
+ def eq_scaled_dot_product_attention(
179
+ self,
180
+ query,
181
+ key,
182
+ value,
183
+ attn_mask=None,
184
+ dropout_p=0.0,
185
+ ) -> torch.Tensor:
186
+ # This is a standard scaled dot product attention
187
+ # It is less efficient, but it does not trigger the CUDA errors seen with the fused kernel
188
+
189
+ L, S = query.size(-2), key.size(-2)
190
+ scale_factor = 1 / math.sqrt(query.size(-1))
191
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
192
+
193
+ if attn_mask is not None:
194
+ if attn_mask.dtype == torch.bool:
195
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
196
+ else:
197
+ attn_bias += attn_mask
198
+
199
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
200
+ attn_weight += attn_bias
201
+ attn_weight = torch.softmax(attn_weight, dim=-1)
202
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
203
+
204
+ return attn_weight @ value
205
+
206
+ # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L599
207
+ class TransformerBlock(nn.Module):
208
+ def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True) -> None:
209
+ super().__init__()
210
+ self.attention = Attention(config, layer_idx, use_sdpa=use_sdpa)
211
+ self.feed_forward = FeedForward(config)
212
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
213
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
214
+
215
+ def forward(
216
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: int = None, kv_cache: KVCache = None
217
+ ) -> Tensor:
218
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, kv_cache)
219
+ out = h + self.feed_forward(self.ffn_norm(h))
220
+ return out
221
+
222
+ # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L470
223
+ class AudioTransformer(nn.Module):
224
+ def __init__(self, config, use_sdpa: bool = False):
225
+ super().__init__()
226
+ self.config = LocalArgs()
227
+ self.config.codebook_size = config.codebook_size
228
+ self.config.num_codebooks = config.num_codebooks
229
+ if hasattr(config, "min_audio_token_id"):
230
+ self.config.min_audio_token_id = config.min_audio_token_id
231
+ self.config.max_audio_token_id = config.max_audio_token_id
232
+ self.config.n_layer = 4
233
+ self.config.dim = 1024
234
+ self.config.n_head = 32
235
+ self.config.n_local_heads = 32
236
+ self.config.intermediate_size = 2816
237
+ self.config.head_dim = self.config.dim // self.config.n_head
238
+ self.config.norm_eps = 1e-5
239
+ self.config.attention_qkv_bias = False
240
+ self.config.dropout = 0.0
241
+
242
+ self.embeddings = nn.Embedding(self.config.codebook_size, self.config.dim)
243
+ if self.config.dim != config.hidden_size:
244
+ self.input_proj = nn.Linear(config.hidden_size, self.config.dim, bias=False)
245
+ else:
246
+ self.input_proj = nn.Identity()
247
+ self.layers = nn.ModuleList(
248
+ TransformerBlock(self.config, layer_idx, use_sdpa=use_sdpa) for layer_idx in range(self.config.n_layer)
249
+ )
250
+ self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps)
251
+ self.token_head = nn.Linear(self.config.dim, self.config.codebook_size, bias=False)
252
+ self.gradient_checkpointing = False
253
+
254
+ self.register_buffer(
255
+ "freqs_cis",
256
+ precompute_freqs_cis(self.config.num_codebooks, self.config.dim // self.config.n_head, 10000),
257
+ persistent=False,
258
+ )
259
+ self.register_buffer(
260
+ "attention_mask",
261
+ torch.tril(torch.ones(self.config.num_codebooks, self.config.num_codebooks, dtype=torch.bool)),
262
+ persistent=False,
263
+ )
264
+
265
+ def run_model(self, hidden_states, freqs_cis, attention_mask, input_pos: int = None, kv_cache: KVCache = None):
266
+ for layer in self.layers:
267
+ # TODO: gradient_checkpointing is disabled because of bug
268
+ if False: # self.gradient_checkpointing and self.training:
269
+ hidden_states = self._gradient_checkpointing_func(
270
+ layer.__call__,
271
+ hidden_states,
272
+ freqs_cis,
273
+ attention_mask,
274
+ use_reentrant=True,
275
+ )
276
+ else:
277
+ hidden_states = layer(hidden_states, freqs_cis, attention_mask, input_pos, kv_cache)
278
+ hidden_states = self.norm(hidden_states)
279
+ logits = self.token_head(hidden_states)
280
+ return logits.float()
281
+
282
+ # inp: [bs, hidden_size]
283
+ # labels: [bs, num_codebooks]
284
+ # logits: [bs, num_codebooks, codebook_size]
285
+ def forward(self, inp, labels):
286
+ bs = inp.shape[0]
287
+
288
+ hidden_states = self.input_proj(inp)
289
+ if self.freqs_cis.dtype != hidden_states.dtype:
290
+ self.freqs_cis = self.freqs_cis.to(dtype=hidden_states.dtype)
291
+ if labels is not None:
292
+ # Training mode
293
+ # Get embedding
294
+ assert bs == labels.shape[0] and labels.shape[1] == self.config.num_codebooks, f"Labels shape error: {labels.shape}"
295
+ hidden_states = [hidden_states[:, None, :], self.embeddings(labels[..., :-1]).to(hidden_states.dtype)]
296
+ hidden_states = torch.cat(hidden_states, dim=1) # [bs, num_codebooks, hidden_size]
297
+ # Run attention layers
298
+ logits = self.run_model(hidden_states, self.freqs_cis, self.attention_mask)
299
+ else:
300
+ # Inference mode
301
+ raise RuntimeError('Please call the "inference" method when running in inference mode')
302
+ return logits
303
+
304
+ # inp: [bs, seq_len, hidden_size]
305
+ # out_tokens: [bs, 1, num_codebooks]
306
+ @torch.inference_mode()
307
+ def inference(self, inp, temperature=0, top_k=0):
308
+ # Only use the last hidden states for token computation
309
+ inp = inp[:, -1:, :]
310
+
311
+ bs = inp.shape[0]
312
+ if self.freqs_cis.dtype != inp.dtype:
313
+ self.freqs_cis = self.freqs_cis.to(dtype=inp.dtype)
314
+
315
+ inp = self.input_proj(inp)
316
+
317
+ # Inference mode
318
+ kv_cache = KVCache(
319
+ self.config.n_layer,
320
+ bs,
321
+ self.config.num_codebooks,
322
+ self.config.n_head,
323
+ self.config.head_dim,
324
+ dtype=inp.dtype,
325
+ device=inp.device,
326
+ )
327
+ # Generate one token per step
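+ # The num_codebooks tokens of one audio frame are sampled sequentially; each step is conditioned on the previously sampled codebooks through the KV cache.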
328
+ out_tokens = []
329
+ for input_pos in range(self.config.num_codebooks):
330
+ inp = inp.reshape(bs, 1, self.config.dim)
331
+ local_freqs_cis = self.freqs_cis[input_pos]
332
+ local_mask = self.attention_mask[None, None, input_pos, :self.config.num_codebooks]
333
+
334
+ logits = self.run_model(inp, local_freqs_cis, local_mask, input_pos, kv_cache)
335
+ logits = logits.squeeze(dim=1)
336
+
337
+ # Apply temperature and top-k
338
+ if temperature > 0:
339
+ logits = logits / temperature
340
+ if top_k > 0:
341
+ top_k = min(top_k, logits.size(-1)) # Safety check
342
+ # Remove all tokens with a probability less than the last token of the top-k
343
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
344
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
345
+
346
+ # Do sample
347
+ probs = nn.functional.softmax(logits, dim=-1)
348
+ next_tokens = torch.multinomial(probs, num_samples=1)
349
+
350
+ next_tokens = next_tokens.reshape(bs, 1, 1)
351
+ inp = self.embeddings(next_tokens)
352
+ out_tokens.append(next_tokens)
353
+
354
+ return torch.cat(out_tokens, dim=-1)
examples/character_ref_emb_demo.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a869512a59e4447c19ecb283d6e0097bf71eaf57e8fa98712afd7c41acbbb554
3
+ size 23264
examples/test1.mp3 ADDED
Binary file (19.2 kB).
 
examples/test_autonomous1.mp3 ADDED
Binary file (52.7 kB).
 
infer.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ import argparse
3
+ import random
4
+ import jsonlines
5
+ import soundfile as sf
6
+ import json
7
+ import copy
8
+ import torch
9
+ from pathlib import Path
10
+ from threading import Thread
11
+
12
+ import torchaudio
13
+ from transformers import AutoTokenizer
14
+
15
+ from model import VoilaAudioAlphaModel, VoilaModel, VoilaAutonomousModel
16
+ from spkr import SpeakerEmbedding
17
+ from voila_tokenizer import VoilaTokenizer
18
+ from tokenize_func import (
19
+ voila_input_format,
20
+ AUDIO_TOKEN_FORMAT,
21
+ DEFAULT_AUDIO_TOKEN,
22
+ DEFAULT_ASSISTANT_TOKEN,
23
+ )
24
+
25
+
26
+ def disable_torch_init():
27
+ """
28
+ Disable the redundant torch default initialization to accelerate model creation.
29
+ """
30
+ import torch
31
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
32
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
33
+
34
+ def load_model(model_name, audio_tokenizer_path):
35
+ disable_torch_init()
36
+
37
+ if "Voila-audio" in model_name:
38
+ model_type = "audio"
39
+ cls = VoilaAudioAlphaModel
40
+ elif "Voila-auto" in model_name:
41
+ model_type = "autonomous"
42
+ cls = VoilaAutonomousModel
43
+ else:
44
+ model_type = "token"
45
+ cls = VoilaModel
46
+
47
+ model = cls.from_pretrained(
48
+ model_name,
49
+ torch_dtype=torch.bfloat16,
50
+ use_flash_attention_2=True,
51
+ use_cache=True,
52
+ )
53
+ model = model.cuda()
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+ tokenizer_voila = VoilaTokenizer(model_path=audio_tokenizer_path, device="cuda")
56
+ return model, tokenizer, tokenizer_voila, model_type
57
+
58
+ def is_audio_output_task(task_type):
59
+ return task_type.endswith("ao") or "aiao" in task_type or "tts" in task_type
60
+
61
+ def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history, ref_embs, ref_embs_mask, max_new_tokens=512):
62
+ # step1: initializing
63
+ num_codebooks = model.config.num_codebooks
64
+ codebook_size = model.config.codebook_size
65
+
66
+ AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))
67
+ assert isinstance(AUDIO_MIN_TOKEN_ID, int)
68
+ AUDIO_MAX_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(codebook_size*num_codebooks-1))
69
+ assert isinstance(AUDIO_MAX_TOKEN_ID, int)
70
+ AUDIO_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_AUDIO_TOKEN)
71
+ assert isinstance(AUDIO_TOKEN_ID, int)
72
+ ASSISTANT_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_ASSISTANT_TOKEN)
73
+ assert isinstance(ASSISTANT_TOKEN_ID, int)
74
+
75
+ # step2: set infer config
76
+ data_cfg = {
77
+ "input_type": model_type,
78
+ "task_type": task_type,
79
+ "num_codebooks": num_codebooks,
80
+ "codebook_size": codebook_size,
81
+ }
82
+
83
+ # step3: infer
84
+ input_ids, audio_datas, audio_data_masks, streaming_user_input_audio_tokens = voila_input_format(history, tokenizer, tokenizer_voila, data_cfg)
85
+
86
+ # prepare user_streaming_generator to simulate streaming user input
87
+ def get_input_generator(all_tokens):
88
+ assert all_tokens is not None
89
+ for i in range(len(all_tokens[0])):
90
+ yield all_tokens[:,i]
91
+
92
+ if model_type == "autonomous":
93
+ input_generator = get_input_generator(torch.as_tensor(streaming_user_input_audio_tokens).cuda())
94
+ input_ids = [torch.as_tensor([input]).transpose(1,2).cuda() for input in input_ids] # transpose to [bs, seq, num_codebooks]
95
+ input_ids = torch.cat(input_ids, dim=2) # concat to [bs, seq, num_codebooks*2]
96
+ else:
97
+ input_ids = torch.as_tensor([input_ids]).transpose(1,2).cuda() # transpose to [bs, seq, num_codebooks]
98
+ gen_params = {
99
+ "input_ids": input_ids,
100
+ "ref_embs": ref_embs,
101
+ "ref_embs_mask": ref_embs_mask,
102
+ "max_new_tokens": max_new_tokens,
103
+ "pad_token_id": tokenizer.pad_token_id,
104
+ "eos_token_id": tokenizer.eos_token_id,
105
+ "llm_audio_token_id": AUDIO_TOKEN_ID,
106
+ "min_audio_token_id": AUDIO_MIN_TOKEN_ID,
107
+ "temperature": 0.2,
108
+ "top_k": 50,
109
+ "audio_temperature": 0.8,
110
+ "audio_top_k": 50,
111
+ }
112
+ if model_type == "audio":
113
+ audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).cuda()
114
+ audio_data_masks = torch.tensor([audio_data_masks]).cuda()
115
+ gen_params["audio_datas"] = audio_datas
116
+ gen_params["audio_data_masks"] = audio_data_masks
117
+ elif model_type == "autonomous":
118
+ gen_params["input_generator"] = input_generator
119
+ gen_params["llm_assistant_token_id"] = ASSISTANT_TOKEN_ID
120
+ print(f"Input str: {tokenizer.decode(input_ids[0, :, 0])}")
121
+ with torch.inference_mode():
122
+ outputs = model.run_generate(**gen_params)
123
+
124
+ if model_type == "autonomous":
125
+ outputs = outputs.chunk(2, dim=2)[1]
126
+ outputs = outputs[0].cpu().tolist()
127
+
128
+ predict_outputs = outputs[input_ids.shape[1]:]
129
+ text_outputs = []
130
+ audio_outputs = []
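+ # Collect one list per codebook: generated ids in the audio range are mapped back to per-codebook indices, everything else (except EOS) is treated as text.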
131
+ for _ in range(num_codebooks):
132
+ audio_outputs.append([])
133
+ for item in predict_outputs:
134
+ if item[0] >= AUDIO_MIN_TOKEN_ID and item[0] <= AUDIO_MAX_TOKEN_ID:
135
+ for n, at in enumerate(item):
136
+ audio_outputs[n].append((at - AUDIO_MIN_TOKEN_ID)%codebook_size)
137
+ elif item[0] != tokenizer.eos_token_id:
138
+ text_outputs.append(item[0])
139
+
140
+ out = {
141
+ 'text': tokenizer.decode(text_outputs),
142
+ }
143
+ if is_audio_output_task(task_type):
144
+ audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).cuda())
145
+ out['audio'] = (audio_values.detach().cpu().numpy(), 16000)
146
+ return out
147
+
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--instruction", type=str, default="")
152
+ parser.add_argument("--input-text", type=str, default=None)
153
+ parser.add_argument("--input-audio", type=str, default=None)
154
+ parser.add_argument("--result-path", type=str, default="output")
155
+ parser.add_argument("--ref-audio", type=str, default="examples/test1.mp3")
156
+ parser.add_argument("--model-name", type=str, default="maitrix-org/Voila-chat")
157
+ parser.add_argument("--audio-tokenizer-path", type=str, default="maitrix-org/Voila-Tokenizer")
158
+ parser.add_argument("--task-type", type=str, default="chat_aiao")
159
+ args = parser.parse_args()
160
+
161
+ assert args.model_name in [
162
+ "maitrix-org/Voila-audio-alpha",
163
+ "maitrix-org/Voila-base",
164
+ "maitrix-org/Voila-chat",
165
+ "maitrix-org/Voila-autonomous-preview",
166
+ ]
167
+
168
+ # step0: Model loading
169
+ model, tokenizer, tokenizer_voila, model_type = load_model(args.model_name, args.audio_tokenizer_path)
170
+
171
+ # step1: prepare inputs
172
+ Path(args.result_path).mkdir(exist_ok=True, parents=True)
173
+ history = {
174
+ "instruction": args.instruction,
175
+ "conversations": [],
176
+ }
177
+ if args.input_text is not None:
178
+ history["conversations"].append({"from": "user", "text": args.input_text})
179
+ elif args.input_audio is not None:
180
+ history["conversations"].append({"from": "user", "audio": {"file": args.input_audio}})
181
+ else:
182
+ raise Exception("Please provide atleast one of --input-text and --input-audio")
183
+ history["conversations"].append({"from": "assistant"})
184
+
185
+ # step2: encode ref
186
+ ref_embs, ref_embs_mask = None, None
187
+ if is_audio_output_task(args.task_type):
188
+ spkr_model = SpeakerEmbedding(device="cuda")
189
+ wav, sr = torchaudio.load(args.ref_audio)
190
+ ref_embs = spkr_model(wav, sr)
191
+ ref_embs_mask = torch.tensor([1]).cuda()
192
+
193
+ out = eval_model(model, tokenizer, tokenizer_voila, model_type, args.task_type, history, ref_embs, ref_embs_mask)
194
+ print(f"Output str: {out['text']}")
195
+ if 'audio' in out:
196
+ wav, sr = out['audio']
197
+ save_name = f"{args.result_path}/out.wav"
198
+ sf.write(save_name, wav, sr)
model.py ADDED
@@ -0,0 +1,1397 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union, Dict, Any
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import CrossEntropyLoss
9
+
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.utils import ModelOutput, logging
12
+ from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel
13
+
14
+ from audio_transformer import AudioTransformer
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43
20
+ class LayerNorm(torch.nn.LayerNorm):
21
+ """Layer norm with transpose"""
22
+
23
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
24
+ x = input.transpose(-2, -1)
25
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
26
+ x = x.transpose(-2, -1)
27
+ return x
28
+
29
+ # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53
30
+ class ConvLayerBlock(torch.nn.Module):
31
+ """Convolution unit of FeatureExtractor"""
32
+
33
+ def __init__(
34
+ self,
35
+ in_channels: int,
36
+ out_channels: int,
37
+ kernel_size: int,
38
+ stride: int,
39
+ bias: bool,
40
+ layer_norm: Optional[torch.nn.Module],
41
+ ):
42
+ super().__init__()
43
+ self.kernel_size = kernel_size
44
+ self.stride = stride
45
+ self.layer_norm = layer_norm
46
+ self.conv = torch.nn.Conv1d(
47
+ in_channels=in_channels,
48
+ out_channels=out_channels,
49
+ kernel_size=kernel_size,
50
+ stride=stride,
51
+ bias=bias,
52
+ )
53
+
54
+ def forward(
55
+ self,
56
+ x: torch.Tensor,
57
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
58
+ """
59
+ Args:
60
+ x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
61
+ Returns:
62
+ Tensor: Shape ``[batch, out_channels, out_frames]``.
63
+ Optional[Tensor]: Shape ``[batch, ]``.
64
+ """
65
+ x = self.conv(x)
66
+ if self.layer_norm is not None:
67
+ x = self.layer_norm(x)
68
+ x = torch.nn.functional.gelu(x)
69
+
70
+ return x
71
+
72
+ # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146
73
+ class FeatureProjection(torch.nn.Module):
74
+ """Layer that connects FeatureExtractor and Encoder
75
+
76
+ Projects features to encoder dimension.
77
+
78
+ Args:
79
+ in_features (int): Input feature dim.
80
+ out_features (int): Output feature dim.
81
+ dropout (float): Dropout probability.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ in_features: int,
87
+ out_features: int,
88
+ dropout=0.1,
89
+ ):
90
+ super().__init__()
91
+ self.layer_norm = torch.nn.LayerNorm(in_features)
92
+ self.projection = torch.nn.Linear(
93
+ in_features,
94
+ out_features,
95
+ )
96
+ self.dropout = torch.nn.Dropout(dropout)
97
+
98
+ def forward(self, x):
99
+ """
100
+ Args:
101
+ x (Tensor):
102
+ Feature Tensor. shape: ``[batch, frame, in_feature]``
103
+ Returns:
104
+ Tensor: Projected features. ``[batch, frame, out_feature]``.
105
+ """
106
+ x = self.layer_norm(x)
107
+ x = self.projection(x)
108
+ x = self.dropout(x)
109
+ return x
110
+
111
+ # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
112
+ class FeatureExtractor(torch.nn.Module):
113
+ """Extract features from audio
114
+
115
+ Args:
116
+ conv_layers (nn.ModuleList):
117
+ convolution layers
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ shapes=[(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)],
123
+ bias=False,
124
+ norm_mode="group_norm",
125
+ ):
126
+ super().__init__()
127
+ if norm_mode not in ["group_norm", "layer_norm"]:
128
+ raise ValueError("Invalid norm mode")
129
+ blocks = []
130
+ in_channels = 1
131
+ for i, (out_channels, kernel_size, stride) in enumerate(shapes):
132
+ normalization = None
133
+ if norm_mode == "group_norm" and i == 0:
134
+ normalization = torch.nn.GroupNorm(
135
+ num_groups=out_channels,
136
+ num_channels=out_channels,
137
+ affine=True,
138
+ )
139
+ elif norm_mode == "layer_norm":
140
+ normalization = LayerNorm(
141
+ normalized_shape=out_channels,
142
+ elementwise_affine=True,
143
+ )
144
+ blocks.append(
145
+ ConvLayerBlock(
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ kernel_size=kernel_size,
149
+ stride=stride,
150
+ bias=bias,
151
+ layer_norm=normalization,
152
+ )
153
+ )
154
+ in_channels = out_channels
155
+ self.conv_layers = torch.nn.ModuleList(blocks)
156
+
157
+ def forward(
158
+ self,
159
+ x: torch.Tensor,
160
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
161
+ """
162
+ Args:
163
+ x (Tensor):
164
+ Input Tensor representing a batch of audio,
165
+ shape: ``[batch, time]``.
166
+
167
+ Returns:
168
+ Tensor:
169
+ The resulting feature, shape: ``[batch, frame, feature]``
170
+ Optional[Tensor]:
171
+ Valid length of each output sample. shape: ``[batch, ]``.
172
+ """
173
+ if x.ndim != 2:
174
+ raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")
175
+
176
+ x = x.unsqueeze(1) # (batch, channel==1, frame)
177
+ for layer in self.conv_layers:
178
+ x = layer(x) # (batch, feature, frame)
179
+ x = x.transpose(1, 2) # (batch, frame, feature)
180
+ return x
181
+
182
+ # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
183
+ class FeatureExtractorAdapter(torch.nn.Module):
184
+ """Extract features from audio
185
+
186
+ Args:
187
+ conv_layers (nn.ModuleList):
188
+ convolution layers
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ shapes=(512, 512, 2, 2),
194
+ hidden_size=2048,
195
+ bias=False,
196
+ norm_mode="group_norm",
197
+ ):
198
+ super().__init__()
199
+ if norm_mode not in ["group_norm", "layer_norm"]:
200
+ raise ValueError("Invalid norm mode")
201
+ in_channels, out_channels, kernel_size, stride = shapes
202
+ normalization = LayerNorm(
203
+ normalized_shape=out_channels,
204
+ elementwise_affine=True,
205
+ )
206
+ self.conv_layers = ConvLayerBlock(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=kernel_size,
210
+ stride=stride,
211
+ bias=False,
212
+ layer_norm=normalization,
213
+ )
214
+ self.feat_proj = FeatureProjection(out_channels, hidden_size)
215
+
216
+ def forward(
217
+ self,
218
+ x: torch.Tensor,
219
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
220
+ """
221
+ Args:
222
+ x (Tensor):
223
+ Input Tensor representing a batch of audio,
224
+ shape: ``[batch, time]``.
225
+
226
+ Returns:
227
+ Tensor:
228
+ The resulting feature, shape: ``[batch, frame, feature]``
229
+ Optional[Tensor]:
230
+ Valid length of each output sample. shape: ``[batch, ]``.
231
+ """
232
+ x = x.transpose(1, 2) # (batch, feature, frame)
233
+ x = self.conv_layers(x) # (batch, feature, frame)
234
+ x = x.transpose(1, 2) # (batch, frame, feature)
235
+ x = self.feat_proj(x)
236
+ return x
237
+
238
+ @dataclass
239
+ class VoilaOutput(ModelOutput):
240
+ """
241
+ Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678
242
+
243
+ Base class for Voila outputs.
244
+
245
+ Args:
246
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
247
+ Language modeling loss (for next-token prediction).
248
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
249
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
250
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
251
+ The hidden state of the last attention layer.
252
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
253
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
254
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
255
+
256
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
257
+ `past_key_values` input) to speed up sequential decoding.
258
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
259
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
260
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
261
+
262
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
263
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
264
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
265
+ sequence_length)`.
266
+
267
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
268
+ heads.
269
+ """
270
+
271
+ loss: Optional[torch.FloatTensor] = None
272
+ logits: torch.FloatTensor = None
273
+ last_hidden_state: torch.FloatTensor = None
274
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
275
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
276
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
277
+ voila_pred: Optional[torch.FloatTensor] = None
278
+
279
+
280
+ # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
281
+ class VoilaModel(LlamaPreTrainedModel):
282
+ _tied_weights_keys = ["lm_head.weight"]
283
+
284
+ def __init__(self, config):
285
+ super().__init__(config)
286
+ self.model = LlamaModel(config)
287
+ self.vocab_size = config.vocab_size
288
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
289
+ self.pad_vocab_size_multiple = 64
290
+
291
+ self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
292
+ self.audio_transformer = AudioTransformer(config, use_sdpa=False)
293
+
294
+ # Initialize weights and apply final processing
295
+ self.post_init()
296
+
297
+ def get_input_embeddings(self):
298
+ return self.model.embed_tokens
299
+
300
+ def set_input_embeddings(self, value):
301
+ self.model.embed_tokens = value
302
+
303
+ def get_output_embeddings(self):
304
+ return self.lm_head
305
+
306
+ def set_output_embeddings(self, new_embeddings):
307
+ self.lm_head = new_embeddings
308
+
309
+ def set_decoder(self, decoder):
310
+ self.model = decoder
311
+
312
+ def get_decoder(self):
313
+ return self.model
314
+
315
+ def forward(
316
+ self,
317
+ input_ids: torch.LongTensor = None,
318
+ attention_mask: Optional[torch.Tensor] = None,
319
+ position_ids: Optional[torch.LongTensor] = None,
320
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
321
+ inputs_embeds: Optional[torch.FloatTensor] = None,
322
+ labels: Optional[torch.LongTensor] = None,
323
+ audio_labels: Optional[torch.LongTensor] = None,
324
+ ref_embs: Optional[List[torch.Tensor]] = None,
325
+ ref_embs_mask: Optional[torch.LongTensor] = None,
326
+ use_cache: Optional[bool] = None,
327
+ output_attentions: Optional[bool] = None,
328
+ output_hidden_states: Optional[bool] = None,
329
+ return_dict: Optional[bool] = None,
330
+ cache_position: Optional[torch.LongTensor] = None,
331
+ num_logits_to_keep: int = 0,
332
+ ) -> Union[Tuple, VoilaOutput]:
333
+ r"""
334
+ Args:
335
+ input_ids: [bs, seq_len, num_codebooks]
336
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
337
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
338
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
339
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
340
+ """
341
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342
+ output_hidden_states = (
343
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344
+ )
345
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
+
347
+ if input_ids is not None and inputs_embeds is not None:
348
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
349
+ if inputs_embeds is None:
350
+ inputs_embeds = self.model.embed_tokens(input_ids)
351
+ assert len(inputs_embeds.shape) == 4
352
+ if len(inputs_embeds.shape) == 4:
353
+ inputs_embeds = inputs_embeds.mean(dim=2)
354
+
355
+ if self.training or \
356
+ (past_key_values is None and ref_embs is not None) or \
357
+ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
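+ # The projected speaker embedding is added at a fixed position (index 4) of the prompt, so it is only injected while the prompt is being encoded, not during incremental decoding.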
358
+ ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
359
+ ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
360
+ # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
361
+ padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
362
+ ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
363
+ inputs_embeds = inputs_embeds + ref_embs
364
+
365
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
366
+ outputs = self.model(
367
+ attention_mask=attention_mask,
368
+ position_ids=position_ids,
369
+ past_key_values=past_key_values,
370
+ inputs_embeds=inputs_embeds,
371
+ use_cache=use_cache,
372
+ output_attentions=output_attentions,
373
+ output_hidden_states=output_hidden_states,
374
+ return_dict=return_dict,
375
+ cache_position=cache_position,
376
+ )
377
+
378
+ hidden_states = outputs[0]
379
+ if self.config.pretraining_tp > 1:
380
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
381
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
382
+ logits = torch.cat(logits, dim=-1)
383
+ else:
384
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
385
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
386
+
387
+ loss = None
388
+
389
+ if not return_dict:
390
+ output = (logits,) + outputs[1:]
391
+ return (loss,) + output if loss is not None else output
392
+
393
+ return VoilaOutput(
394
+ loss=loss,
395
+ logits=logits,
396
+ last_hidden_state=hidden_states,
397
+ past_key_values=outputs.past_key_values,
398
+ hidden_states=outputs.hidden_states,
399
+ attentions=outputs.attentions,
400
+ )
401
+
402
+ def _prepare_inputs_for_generation(
403
+ self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
404
+ ):
405
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:
406
+ if isinstance(past_key_values, Cache):
407
+ cache_length = past_key_values.get_seq_length()
408
+ past_length = past_key_values.seen_tokens
409
+ max_cache_length = past_key_values.get_max_cache_shape()
410
+ else:
411
+ cache_length = past_length = past_key_values[0][0].shape[2]
412
+ max_cache_length = None
413
+
414
+ # Keep only the unprocessed tokens:
415
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
416
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
417
+ # input)
418
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
419
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
420
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
421
+ # input_ids based on the past_length.
422
+ elif past_length < input_ids.shape[1]:
423
+ input_ids = input_ids[:, past_length:]
424
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
425
+
426
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
427
+ if (
428
+ max_cache_length is not None
429
+ and attention_mask is not None
430
+ and cache_length + input_ids.shape[1] > max_cache_length
431
+ ):
432
+ attention_mask = attention_mask[:, -max_cache_length:]
433
+
434
+ position_ids = kwargs.get("position_ids", None)
435
+ if attention_mask is not None and position_ids is None:
436
+ # create position_ids on the fly for batch generation
437
+ position_ids = attention_mask.long().cumsum(-1) - 1
438
+ position_ids.masked_fill_(attention_mask == 0, 1)
439
+ if past_key_values:
440
+ position_ids = position_ids[:, -input_ids.shape[1] :]
441
+
442
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
443
+ if inputs_embeds is None and \
444
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
445
+ inputs_embeds = self.model.embed_tokens(input_ids)
446
+ if inputs_embeds is not None and \
447
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
448
+ model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
449
+ else:
450
+ model_inputs = {"input_ids": input_ids, "ref_embs": None}
451
+
452
+ model_inputs.update(
453
+ {
454
+ "position_ids": position_ids,
455
+ "past_key_values": past_key_values,
456
+ "use_cache": kwargs.get("use_cache"),
457
+ "attention_mask": attention_mask,
458
+ }
459
+ )
460
+ return model_inputs
461
+
462
+ def _update_model_kwargs_for_generation(
463
+ self,
464
+ outputs,
465
+ model_kwargs: Dict[str, Any],
466
+ num_new_token: int = 1,
467
+ ) -> Dict[str, Any]:
468
+ # update past_key_values
469
+ model_kwargs["past_key_values"] = outputs.past_key_values
470
+
471
+ # update attention mask
472
+ if "attention_mask" in model_kwargs:
473
+ attention_mask = model_kwargs["attention_mask"]
474
+ model_kwargs["attention_mask"] = torch.cat(
475
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
476
+ )
477
+
478
+ return model_kwargs
479
+
480
+ def _prepare_attention_mask_for_generation(
481
+ self,
482
+ inputs: torch.Tensor,
483
+ pad_token_id: Optional[int],
484
+ eos_token_id: Optional[Union[int, List[int]]],
485
+ ) -> torch.LongTensor:
486
+ is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
487
+ is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
488
+ if isinstance(eos_token_id, int):
489
+ eos_token_id = [eos_token_id]
490
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
491
+
492
+ # Check if input is input_ids and padded -> only then is attention_mask defined
493
+ if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
494
+ return inputs.ne(pad_token_id).long()
495
+ else:
496
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
497
+
498
+ @torch.inference_mode()
499
+ def run_generate(
500
+ self,
501
+ input_ids: torch.LongTensor,
502
+ ref_embs: Optional[List[torch.Tensor]] = None,
503
+ ref_embs_mask: Optional[torch.LongTensor] = None,
504
+ max_new_tokens: Optional[int] = 128,
505
+ pad_token_id: Optional[int] = None,
506
+ eos_token_id: Optional[Union[int, List[int]]] = None,
507
+ streamer: Optional["BaseStreamer"] = None,
508
+ llm_audio_token_id: Optional[int] = None,
509
+ min_audio_token_id: Optional[int] = None,
510
+ temperature=0.2,
511
+ top_k=50,
512
+ audio_temperature=0.2,
513
+ audio_top_k=50,
514
+ ):
515
+ assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
516
+ assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
517
+ assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
518
+
519
+ eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
520
+
521
+ # keep track of which sequences are already finished
522
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
523
+
524
+ # Extend input_ids with additional num_codebooks dim
525
+ if len(input_ids.shape) == 2:
526
+ input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
527
+
528
+ this_peer_finished = False # used by synced_gpus only
529
+ max_length = input_ids.shape[1] + max_new_tokens
530
+
531
+ model_kwargs = {
532
+ "use_cache": True,
533
+ "past_key_values": DynamicCache(),
534
+ "attention_mask": self._prepare_attention_mask_for_generation(
535
+ input_ids, pad_token_id, eos_token_id
536
+ ),
537
+ }
538
+ # auto-regressive generation
539
+ while True:
540
+ # prepare model inputs
541
+ model_inputs = self._prepare_inputs_for_generation(
542
+ input_ids,
543
+ ref_embs=ref_embs,
544
+ ref_embs_mask=ref_embs_mask,
545
+ **model_kwargs
546
+ )
547
+
548
+ # forward pass to get next token
549
+ outputs = self(
550
+ **model_inputs,
551
+ return_dict=True,
552
+ )
553
+ audio_tokens = self.audio_transformer.inference(
554
+ outputs.last_hidden_state,
555
+ temperature=audio_temperature,
556
+ top_k=audio_top_k,
557
+ )
558
+ audio_tokens = torch.stack(
559
+ [
560
+ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size
561
+ for ci in range(self.config.num_codebooks)
562
+ ],
563
+ dim=2,
564
+ )
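+ # Offset the raw codebook indices into the LLM vocabulary: codebook ci occupies its own block of codebook_size ids starting at min_audio_token_id.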
565
+
566
+ next_token_logits = outputs.logits[:, -1, :]
567
+
568
+ # pre-process distribution
569
+ # Apply temperature and top-k
570
+ if temperature > 0:
571
+ next_token_logits = next_token_logits / temperature
572
+ if top_k > 0:
573
+ top_k = min(top_k, next_token_logits.size(-1)) # Safety check
574
+ # Remove all tokens with a probability less than the last token of the top-k
575
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
576
+ next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
577
+
578
+ # sample
579
+ probs = nn.functional.softmax(next_token_logits, dim=-1)
580
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
581
+
582
+ # finished sentences should have their next token be a padding token
583
+ if eos_token_id is not None:
584
+ if pad_token_id is None:
585
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
586
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
587
+
588
+ # Append NUM_CODEBOOK text tokens or audio_tokens
589
+ if len(next_tokens.shape) == 1:
590
+ next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
591
+ next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens)
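+ # Wherever the text head emitted the generic audio placeholder token, substitute the codebook tokens sampled by the audio transformer.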
592
+
593
+ input_ids = torch.cat([input_ids, next_tokens], dim=1)
594
+ if streamer is not None:
595
+ streamer.put(next_tokens.cpu())
596
+ model_kwargs = self._update_model_kwargs_for_generation(
597
+ outputs, model_kwargs
598
+ )
599
+
600
+ # if eos_token was found in one sentence, set sentence to finished
601
+ if eos_token_id_tensor is not None:
602
+ unfinished_sequences = unfinished_sequences.mul(
603
+ next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
604
+ )
605
+
606
+ # stop when each sentence is finished
607
+ if unfinished_sequences.max() == 0:
608
+ this_peer_finished = True
609
+
610
+ # stop if we exceed the maximum length
611
+ if input_ids.shape[1] >= max_length:
612
+ this_peer_finished = True
613
+
614
+ if this_peer_finished:
615
+ break
616
+
617
+ if streamer is not None:
618
+ streamer.end()
619
+
620
+ return input_ids
621
+
622
+
623
+ # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
624
+ class VoilaAudioAlphaModel(LlamaPreTrainedModel):
625
+ _tied_weights_keys = ["lm_head.weight"]
626
+
627
+ def __init__(self, config):
628
+ super().__init__(config)
629
+ self.model = LlamaModel(config)
630
+ self.vocab_size = config.vocab_size
631
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+ self.pad_vocab_size_multiple = 64
633
+
634
+
635
+ self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
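+ # Projects the 256-d speaker reference embedding (as produced by spkr.py) into the LLM hidden size.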
636
+ self.audio_transformer = AudioTransformer(config, use_sdpa=False)
637
+
638
+ self.feature_extractor = FeatureExtractor()
639
+ self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=config.hidden_size)
640
+
641
+ # Initialize weights and apply final processing
642
+ self.post_init()
643
+
644
+ def get_input_embeddings(self):
645
+ return self.model.embed_tokens
646
+
647
+ def set_input_embeddings(self, value):
648
+ self.model.embed_tokens = value
649
+
650
+ def get_output_embeddings(self):
651
+ return self.lm_head
652
+
653
+ def set_output_embeddings(self, new_embeddings):
654
+ self.lm_head = new_embeddings
655
+
656
+ def set_decoder(self, decoder):
657
+ self.model = decoder
658
+
659
+ def get_decoder(self):
660
+ return self.model
661
+
662
+ def forward(
663
+ self,
664
+ input_ids: torch.LongTensor = None,
665
+ attention_mask: Optional[torch.Tensor] = None,
666
+ position_ids: Optional[torch.LongTensor] = None,
667
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
668
+ inputs_embeds: Optional[torch.FloatTensor] = None,
669
+ labels: Optional[torch.LongTensor] = None,
670
+ audio_labels: Optional[torch.LongTensor] = None,
671
+ ref_embs: Optional[List[torch.Tensor]] = None,
672
+ ref_embs_mask: Optional[torch.LongTensor] = None,
673
+ audio_datas: Optional[torch.FloatTensor] = None,
674
+ audio_data_masks: Optional[torch.LongTensor] = None,
675
+ use_cache: Optional[bool] = None,
676
+ output_attentions: Optional[bool] = None,
677
+ output_hidden_states: Optional[bool] = None,
678
+ return_dict: Optional[bool] = None,
679
+ cache_position: Optional[torch.LongTensor] = None,
680
+ num_logits_to_keep: int = 0,
681
+ ) -> Union[Tuple, VoilaOutput]:
682
+ r"""
683
+ Args:
684
+ input_ids: [bs, seq_len, num_codebooks]
685
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
686
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
687
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
688
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
689
+ """
690
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
691
+ output_hidden_states = (
692
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
693
+ )
694
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
695
+
696
+ if input_ids is not None and inputs_embeds is not None:
697
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
698
+ if inputs_embeds is None:
699
+ inputs_embeds = self.model.embed_tokens(input_ids)
700
+ assert len(inputs_embeds.shape) == 4
701
+ if len(inputs_embeds.shape) == 4:
702
+ inputs_embeds = inputs_embeds.mean(dim=2)
703
+
704
+ if self.training or \
705
+ (past_key_values is None and ref_embs is not None) or \
706
+ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
707
+ ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
708
+ ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
709
+ # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
710
+ padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
711
+ ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
712
+ inputs_embeds = inputs_embeds + ref_embs
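+ # The padding above places the projected speaker embedding at sequence position 4 only; every other position receives zeros.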
713
+
714
+ if self.training or audio_datas is not None:
715
+ audio_embeds = self.feature_extractor(audio_datas)
716
+ audio_embeds = self.audio_feature_extractor_adapter(audio_embeds)
717
+ audio_embeds = audio_embeds * audio_data_masks[..., None]
718
+ inputs_embeds = inputs_embeds + audio_embeds
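+ # The "alpha" variant additionally mixes in continuous features extracted from the raw waveform, masked out where no audio is present.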
719
+
720
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
721
+ outputs = self.model(
722
+ attention_mask=attention_mask,
723
+ position_ids=position_ids,
724
+ past_key_values=past_key_values,
725
+ inputs_embeds=inputs_embeds,
726
+ use_cache=use_cache,
727
+ output_attentions=output_attentions,
728
+ output_hidden_states=output_hidden_states,
729
+ return_dict=return_dict,
730
+ cache_position=cache_position,
731
+ )
732
+
733
+ hidden_states = outputs[0]
734
+ if self.config.pretraining_tp > 1:
735
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
736
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
737
+ logits = torch.cat(logits, dim=-1)
738
+ else:
739
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
740
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
741
+
742
+ loss = None
743
+ if labels is not None:
744
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
745
+ logits = logits.float()
746
+ # We shift tokens and labels in dataloader
747
+ shift_logits = logits.contiguous()
748
+ shift_labels = labels.contiguous()
749
+ # Flatten the tokens
750
+ loss_fct = CrossEntropyLoss()
751
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
752
+ shift_labels = shift_labels.view(-1)
753
+ # Enable model parallelism
754
+ shift_labels = shift_labels.to(shift_logits.device)
755
+ loss = loss_fct(shift_logits, shift_labels)
756
+
757
+ if audio_labels is not None:
758
+ au_mask = (audio_labels >= 0).all(dim=-1)
759
+ au_hidden_states = hidden_states[au_mask]
760
+ au_audio_labels = audio_labels[au_mask]
761
+ if len(au_hidden_states) <= 0:
762
+ au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
763
+ au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks)
764
+ loss_weight = 0.0
765
+ else:
766
+ loss_weight = 1.0
767
+ au_logits = self.audio_transformer(au_hidden_states, au_audio_labels)
768
+ # We shift tokens and labels in dataloader
769
+ shift_au_logits = au_logits.contiguous()
770
+ shift_audio_labels = au_audio_labels.contiguous()
771
+ # Flatten the tokens
772
+ loss_fct = CrossEntropyLoss()
773
+ shift_au_logits = shift_au_logits.view(-1, self.config.codebook_size)
774
+ shift_audio_labels = shift_audio_labels.view(-1)
775
+ # Enable model parallelism
776
+ shift_audio_labels = shift_audio_labels.to(shift_au_logits.device)
777
+ au_loss = loss_fct(shift_au_logits, shift_audio_labels)
778
+
779
+ loss += au_loss * loss_weight
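+ # loss_weight is 0 when the batch has no audio-labelled positions, so the dummy au_loss only keeps the audio transformer in the graph (e.g. to avoid unused-parameter errors in distributed training).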
780
+ else:
781
+ # au_tokens = self.audio_transformer.inference(hidden_states)
782
+ pass
783
+
784
+ if not return_dict:
785
+ output = (logits,) + outputs[1:]
786
+ return (loss,) + output if loss is not None else output
787
+
788
+ return VoilaOutput(
789
+ loss=loss,
790
+ logits=logits,
791
+ last_hidden_state=hidden_states,
792
+ past_key_values=outputs.past_key_values,
793
+ hidden_states=outputs.hidden_states,
794
+ attentions=outputs.attentions,
795
+ )
796
+
797
+ def _prepare_inputs_for_generation(
798
+ self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
799
+ ):
800
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:
801
+ if isinstance(past_key_values, Cache):
802
+ cache_length = past_key_values.get_seq_length()
803
+ past_length = past_key_values.seen_tokens
804
+ max_cache_length = past_key_values.get_max_cache_shape()
805
+ else:
806
+ cache_length = past_length = past_key_values[0][0].shape[2]
807
+ max_cache_length = None
808
+
809
+ # Keep only the unprocessed tokens:
810
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
811
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
812
+ # input)
813
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
814
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
815
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
816
+ # input_ids based on the past_length.
817
+ elif past_length < input_ids.shape[1]:
818
+ input_ids = input_ids[:, past_length:]
819
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
820
+
821
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
822
+ if (
823
+ max_cache_length is not None
824
+ and attention_mask is not None
825
+ and cache_length + input_ids.shape[1] > max_cache_length
826
+ ):
827
+ attention_mask = attention_mask[:, -max_cache_length:]
828
+
829
+ position_ids = kwargs.get("position_ids", None)
830
+ if attention_mask is not None and position_ids is None:
831
+ # create position_ids on the fly for batch generation
832
+ position_ids = attention_mask.long().cumsum(-1) - 1
833
+ position_ids.masked_fill_(attention_mask == 0, 1)
834
+ if past_key_values:
835
+ position_ids = position_ids[:, -input_ids.shape[1] :]
836
+
837
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
838
+ if inputs_embeds is None and \
839
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
840
+ inputs_embeds = self.model.embed_tokens(input_ids)
841
+ if inputs_embeds is not None and \
842
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
843
+ model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks}
844
+ else:
845
+ model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None}
846
+
847
+ model_inputs.update(
848
+ {
849
+ "position_ids": position_ids,
850
+ "past_key_values": past_key_values,
851
+ "use_cache": kwargs.get("use_cache"),
852
+ "attention_mask": attention_mask,
853
+ }
854
+ )
855
+ return model_inputs
856
+
857
+ def _update_model_kwargs_for_generation(
858
+ self,
859
+ outputs,
860
+ model_kwargs: Dict[str, Any],
861
+ num_new_token: int = 1,
862
+ ) -> Dict[str, Any]:
863
+ # update past_key_values
864
+ model_kwargs["past_key_values"] = outputs.past_key_values
865
+
866
+ # update attention mask
867
+ if "attention_mask" in model_kwargs:
868
+ attention_mask = model_kwargs["attention_mask"]
869
+ model_kwargs["attention_mask"] = torch.cat(
870
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
871
+ )
872
+
873
+ return model_kwargs
874
+
875
+ def _prepare_attention_mask_for_generation(
876
+ self,
877
+ inputs: torch.Tensor,
878
+ pad_token_id: Optional[int],
879
+ eos_token_id: Optional[Union[int, List[int]]],
880
+ ) -> torch.LongTensor:
881
+ is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
882
+ is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
883
+ if isinstance(eos_token_id, int):
884
+ eos_token_id = [eos_token_id]
885
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
886
+
887
+ # Check if input is input_ids and padded -> only then is attention_mask defined
888
+ if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
889
+ return inputs.ne(pad_token_id).long()
890
+ else:
891
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
892
+
893
+ @torch.inference_mode()
894
+ def run_generate(
895
+ self,
896
+ input_ids: torch.LongTensor,
897
+ ref_embs: Optional[List[torch.Tensor]] = None,
898
+ ref_embs_mask: Optional[torch.LongTensor] = None,
899
+ audio_datas: Optional[torch.FloatTensor] = None,
900
+ audio_data_masks: Optional[torch.LongTensor] = None,
901
+ max_new_tokens: Optional[int] = 128,
902
+ pad_token_id: Optional[int] = None,
903
+ eos_token_id: Optional[Union[int, List[int]]] = None,
904
+ streamer: Optional["BaseStreamer"] = None,
905
+ llm_audio_token_id: Optional[int] = None,
906
+ min_audio_token_id: Optional[int] = None,
907
+ temperature=0.2,
908
+ top_k=50,
909
+ audio_temperature=0.2,
910
+ audio_top_k=50,
911
+ ):
912
+ assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
913
+ assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
914
+ assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
915
+
916
+ eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
917
+
918
+ # keep track of which sequences are already finished
919
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
920
+
921
+ # Extend input_ids with additional num_codebooks dim
922
+ if len(input_ids.shape) == 2:
923
+ input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
924
+
925
+ this_peer_finished = False # used by synced_gpus only
926
+ max_length = input_ids.shape[1] + max_new_tokens
927
+
928
+ model_kwargs = {
929
+ "use_cache": True,
930
+ "past_key_values": DynamicCache(),
931
+ "attention_mask": self._prepare_attention_mask_for_generation(
932
+ input_ids, pad_token_id, eos_token_id
933
+ ),
934
+ }
935
+ # auto-regressive generation
936
+ while True:
937
+ # prepare model inputs
938
+ model_inputs = self._prepare_inputs_for_generation(
939
+ input_ids,
940
+ ref_embs=ref_embs,
941
+ ref_embs_mask=ref_embs_mask,
942
+ audio_datas=audio_datas,
943
+ audio_data_masks=audio_data_masks,
944
+ **model_kwargs
945
+ )
946
+
947
+ # forward pass to get next token
948
+ outputs = self(
949
+ **model_inputs,
950
+ return_dict=True,
951
+ )
952
+ audio_tokens = self.audio_transformer.inference(
953
+ outputs.last_hidden_state,
954
+ temperature=audio_temperature,
955
+ top_k=audio_top_k,
956
+ )
957
+ audio_tokens = torch.stack(
958
+ [
959
+ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size
960
+ for ci in range(self.config.num_codebooks)
961
+ ],
962
+ dim=2,
963
+ )
964
+
965
+ next_token_logits = outputs.logits[:, -1, :]
966
+
967
+ # pre-process distribution
968
+ # Apply temperature and top-k
969
+ if temperature > 0:
970
+ next_token_logits = next_token_logits / temperature
971
+ if top_k > 0:
972
+ top_k = min(top_k, next_token_logits.size(-1)) # Safety check
973
+ # Remove all tokens with a probability less than the last token of the top-k
974
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
975
+ next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
976
+
977
+ # sample
978
+ probs = nn.functional.softmax(next_token_logits, dim=-1)
979
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
980
+
981
+ # finished sentences should have their next token be a padding token
982
+ if eos_token_id is not None:
983
+ if pad_token_id is None:
984
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
985
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
986
+
987
+ # Append NUM_CODEBOOK text tokens or audio_tokens
988
+ if len(next_tokens.shape) == 1:
989
+ next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
990
+ next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens)
991
+
992
+ input_ids = torch.cat([input_ids, next_tokens], dim=1)
993
+ if streamer is not None:
994
+ streamer.put(next_tokens.cpu())
995
+ model_kwargs = self._update_model_kwargs_for_generation(
996
+ outputs, model_kwargs
997
+ )
998
+
999
+ # if eos_token was found in one sentence, set sentence to finished
1000
+ if eos_token_id_tensor is not None:
1001
+ unfinished_sequences = unfinished_sequences.mul(
1002
+ next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
1003
+ )
1004
+
1005
+ # stop when each sentence is finished
1006
+ if unfinished_sequences.max() == 0:
1007
+ this_peer_finished = True
1008
+
1009
+ # stop if we exceed the maximum length
1010
+ if input_ids.shape[1] >= max_length:
1011
+ this_peer_finished = True
1012
+
1013
+ if this_peer_finished:
1014
+ break
1015
+
1016
+ if streamer is not None:
1017
+ streamer.end()
1018
+
1019
+ return input_ids
1020
+
1021
+
1022
+ # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
1023
+ class VoilaAutonomousModel(LlamaPreTrainedModel):
1024
+ _tied_weights_keys = ["lm_head.weight"]
1025
+
1026
+ def __init__(self, config):
1027
+ super().__init__(config)
1028
+ self.model = LlamaModel(config)
1029
+ self.vocab_size = config.vocab_size
1030
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1031
+ self.pad_vocab_size_multiple = 64
1032
+
1033
+ self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
1034
+ self.audio_transformer = AudioTransformer(config, use_sdpa=False)
1035
+ self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),)
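+ # Binary head: at every step of full-duplex (autonomous) dialogue it predicts whether the model should take the turn and start speaking.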
1036
+
1037
+ # Initialize weights and apply final processing
1038
+ self.post_init()
1039
+
1040
+ def get_input_embeddings(self):
1041
+ return self.model.embed_tokens
1042
+
1043
+ def set_input_embeddings(self, value):
1044
+ self.model.embed_tokens = value
1045
+
1046
+ def get_output_embeddings(self):
1047
+ return self.lm_head
1048
+
1049
+ def set_output_embeddings(self, new_embeddings):
1050
+ self.lm_head = new_embeddings
1051
+
1052
+ def set_decoder(self, decoder):
1053
+ self.model = decoder
1054
+
1055
+ def get_decoder(self):
1056
+ return self.model
1057
+
1058
+ def forward(
1059
+ self,
1060
+ input_ids: torch.LongTensor = None,
1061
+ attention_mask: Optional[torch.Tensor] = None,
1062
+ position_ids: Optional[torch.LongTensor] = None,
1063
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1064
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1065
+ labels: Optional[torch.LongTensor] = None,
1066
+ audio_labels: Optional[torch.LongTensor] = None,
1067
+ voila_labels: Optional[torch.LongTensor] = None,
1068
+ ref_embs: Optional[List[torch.Tensor]] = None,
1069
+ ref_embs_mask: Optional[torch.LongTensor] = None,
1070
+ use_cache: Optional[bool] = None,
1071
+ output_attentions: Optional[bool] = None,
1072
+ output_hidden_states: Optional[bool] = None,
1073
+ return_dict: Optional[bool] = None,
1074
+ cache_position: Optional[torch.LongTensor] = None,
1075
+ num_logits_to_keep: int = 0,
1076
+ ) -> Union[Tuple, VoilaOutput]:
1077
+ r"""
1078
+ Args:
1079
+ input_ids: [bs, seq_len, num_codebooks]
1080
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1081
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1082
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1083
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1084
+ """
1085
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1086
+ output_hidden_states = (
1087
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1088
+ )
1089
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1090
+
1091
+ if input_ids is not None and inputs_embeds is not None:
1092
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1093
+ if inputs_embeds is None:
1094
+ inputs_embeds = self.model.embed_tokens(input_ids)
1095
+ assert len(inputs_embeds.shape) == 4
1096
+ if len(inputs_embeds.shape) == 4:
1097
+ inputs_embeds = inputs_embeds.mean(dim=2)
1098
+
1099
+ if self.training or \
1100
+ (past_key_values is None and ref_embs is not None) or \
1101
+ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
1102
+ ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
1103
+ ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
1104
+ # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
1105
+ padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
1106
+ ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
1107
+ inputs_embeds = inputs_embeds + ref_embs
1108
+
1109
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1110
+ outputs = self.model(
1111
+ attention_mask=attention_mask,
1112
+ position_ids=position_ids,
1113
+ past_key_values=past_key_values,
1114
+ inputs_embeds=inputs_embeds,
1115
+ use_cache=use_cache,
1116
+ output_attentions=output_attentions,
1117
+ output_hidden_states=output_hidden_states,
1118
+ return_dict=return_dict,
1119
+ cache_position=cache_position,
1120
+ )
1121
+
1122
+ hidden_states = outputs[0]
1123
+ if self.config.pretraining_tp > 1:
1124
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1125
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1126
+ logits = torch.cat(logits, dim=-1)
1127
+ else:
1128
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1129
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1130
+
1131
+ # calc voila_predict_loss
1132
+ voila_pred = self.voila_predictor(hidden_states)
1133
+ voila_pred = voila_pred.float()
1134
+
1135
+ loss = None
1136
+
1137
+ if not return_dict:
1138
+ output = (logits,) + outputs[1:]
1139
+ return (loss,) + output if loss is not None else output
1140
+
1141
+ return VoilaOutput(
1142
+ loss=loss,
1143
+ logits=logits,
1144
+ last_hidden_state=hidden_states,
1145
+ past_key_values=outputs.past_key_values,
1146
+ hidden_states=outputs.hidden_states,
1147
+ attentions=outputs.attentions,
1148
+ voila_pred=voila_pred,
1149
+ )
1150
+
1151
+ def _prepare_inputs_for_generation(
1152
+ self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1153
+ ):
1154
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:
1155
+ if isinstance(past_key_values, Cache):
1156
+ cache_length = past_key_values.get_seq_length()
1157
+ past_length = past_key_values.seen_tokens
1158
+ max_cache_length = past_key_values.get_max_cache_shape()
1159
+ else:
1160
+ cache_length = past_length = past_key_values[0][0].shape[2]
1161
+ max_cache_length = None
1162
+
1163
+ # Keep only the unprocessed tokens:
1164
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1165
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1166
+ # input)
1167
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1168
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1169
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1170
+ # input_ids based on the past_length.
1171
+ elif past_length < input_ids.shape[1]:
1172
+ input_ids = input_ids[:, past_length:]
1173
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1174
+
1175
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1176
+ if (
1177
+ max_cache_length is not None
1178
+ and attention_mask is not None
1179
+ and cache_length + input_ids.shape[1] > max_cache_length
1180
+ ):
1181
+ attention_mask = attention_mask[:, -max_cache_length:]
1182
+
1183
+ position_ids = kwargs.get("position_ids", None)
1184
+ if attention_mask is not None and position_ids is None:
1185
+ # create position_ids on the fly for batch generation
1186
+ position_ids = attention_mask.long().cumsum(-1) - 1
1187
+ position_ids.masked_fill_(attention_mask == 0, 1)
1188
+ if past_key_values:
1189
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1190
+
1191
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1192
+ if inputs_embeds is None and \
1193
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
1194
+ inputs_embeds = self.model.embed_tokens(input_ids)
1195
+ if inputs_embeds is not None and \
1196
+ (past_key_values is None or past_key_values.get_seq_length() <= 0):
1197
+ model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
1198
+ else:
1199
+ model_inputs = {"input_ids": input_ids, "ref_embs": None}
1200
+
1201
+ model_inputs.update(
1202
+ {
1203
+ "position_ids": position_ids,
1204
+ "past_key_values": past_key_values,
1205
+ "use_cache": kwargs.get("use_cache"),
1206
+ "attention_mask": attention_mask,
1207
+ }
1208
+ )
1209
+ return model_inputs
1210
+
1211
+ def _update_model_kwargs_for_generation(
1212
+ self,
1213
+ outputs,
1214
+ model_kwargs: Dict[str, Any],
1215
+ num_new_token: int = 1,
1216
+ ) -> Dict[str, Any]:
1217
+ # update past_key_values
1218
+ model_kwargs["past_key_values"] = outputs.past_key_values
1219
+
1220
+ # update attention mask
1221
+ if "attention_mask" in model_kwargs:
1222
+ attention_mask = model_kwargs["attention_mask"]
1223
+ model_kwargs["attention_mask"] = torch.cat(
1224
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
1225
+ )
1226
+
1227
+ return model_kwargs
1228
+
1229
+ def _prepare_attention_mask_for_generation(
1230
+ self,
1231
+ inputs: torch.Tensor,
1232
+ pad_token_id: Optional[int],
1233
+ eos_token_id: Optional[Union[int, List[int]]],
1234
+ ) -> torch.LongTensor:
1235
+ is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
1236
+ is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
1237
+ if isinstance(eos_token_id, int):
1238
+ eos_token_id = [eos_token_id]
1239
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
1240
+
1241
+ # Check if input is input_ids and padded -> only then is attention_mask defined
1242
+ if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
1243
+ return inputs.ne(pad_token_id).long()
1244
+ else:
1245
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
1246
+
1247
+ @torch.inference_mode()
1248
+ def run_generate(
1249
+ self,
1250
+ input_ids: torch.LongTensor,
1251
+ input_generator,
1252
+ ref_embs: Optional[List[torch.Tensor]] = None,
1253
+ ref_embs_mask: Optional[torch.LongTensor] = None,
1254
+ max_new_tokens: Optional[int] = 128,
1255
+ pad_token_id: Optional[int] = None,
1256
+ eos_token_id: Optional[Union[int, List[int]]] = None,
1257
+ streamer: Optional["BaseStreamer"] = None,
1258
+ llm_audio_token_id: Optional[int] = None,
1259
+ min_audio_token_id: Optional[int] = None,
1260
+ llm_assistant_token_id: Optional[int] = None,
1261
+ temperature=0.2,
1262
+ top_k=50,
1263
+ audio_temperature=0.8,
1264
+ audio_top_k=50,
1265
+ ):
1266
+ assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
1267
+ assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
1268
+ assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
1269
+
1270
+ eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
1271
+
1272
+ # keep track of which sequences are already finished
1273
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1274
+
1275
+ # Extend input_ids with additional num_codebooks dim
1276
+ input_ids = input_ids.clone()
1277
+ if len(input_ids.shape) == 2:
1278
+ input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
1279
+
1280
+ this_peer_finished = False # used by synced_gpus only
1281
+ max_length = input_ids.shape[1] + max_new_tokens
1282
+
1283
+ model_kwargs = {
1284
+ "use_cache": True,
1285
+ "past_key_values": DynamicCache(),
1286
+ "attention_mask": self._prepare_attention_mask_for_generation(
1287
+ input_ids, pad_token_id, eos_token_id
1288
+ ),
1289
+ }
1290
+ speaking = False
1291
+ # auto-regressive generation
1292
+ while True:
1293
+ # prepare model inputs
1294
+ model_inputs = self._prepare_inputs_for_generation(
1295
+ input_ids,
1296
+ ref_embs=ref_embs,
1297
+ ref_embs_mask=ref_embs_mask,
1298
+ **model_kwargs
1299
+ )
1300
+
1301
+ # forward pass to get next token
1302
+ outputs = self(
1303
+ **model_inputs,
1304
+ return_dict=True,
1305
+ )
1306
+ audio_tokens = self.audio_transformer.inference(
1307
+ outputs.last_hidden_state,
1308
+ temperature=audio_temperature,
1309
+ top_k=audio_top_k,
1310
+ )
1311
+ audio_tokens = torch.stack(
1312
+ [
1313
+ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size
1314
+ for ci in range(self.config.num_codebooks)
1315
+ ],
1316
+ dim=2,
1317
+ )
1318
+
1319
+ next_token_logits = outputs.logits[:, -1, :]
1320
+
1321
+ # voila head output
1322
+ voila_head_pred = outputs.voila_pred[:, -1, :]
1323
+ voila_head_pred = torch.argmax(voila_head_pred, dim=-1)
1324
+ voila_head_pred = voila_head_pred.cpu()[0].item()
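+ # Note: taking element [0] assumes batch size 1 for autonomous streaming inference.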
1325
+
1326
+ # pre-process distribution
1327
+ # Apply temperature and top-k
1328
+ if temperature > 0:
1329
+ next_token_logits = next_token_logits / temperature
1330
+ if top_k > 0:
1331
+ top_k = min(top_k, next_token_logits.size(-1)) # Safety check
1332
+ # Remove all tokens with a probability less than the last token of the top-k
1333
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
1334
+ next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
1335
+
1336
+ # sample
1337
+ probs = nn.functional.softmax(next_token_logits, dim=-1)
1338
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1339
+
1340
+ # voila head pred == 1, use assistant token
1341
+ if voila_head_pred == 1 and not speaking:
1342
+ next_tokens[0] = llm_assistant_token_id
1343
+ speaking = True
1344
+ elif next_tokens[0] == eos_token_id:
1345
+ speaking = False
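+ # Turn-taking state machine: when the predictor fires while the model is silent, force the assistant token to start a response; an EOS from the text head ends the turn.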
1346
+
1347
+ # finished sentences should have their next token be a padding token
1348
+ if eos_token_id is not None:
1349
+ if pad_token_id is None:
1350
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
1351
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1352
+
1353
+ # Append NUM_CODEBOOK text tokens or audio_tokens
1354
+ if len(next_tokens.shape) == 1:
1355
+ next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
1356
+ audio_token_mask = next_tokens == llm_audio_token_id
1357
+ next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask
1358
+
1359
+ if audio_token_mask[0, 0, 0].item():
1360
+ try:
1361
+ new_input_tokens = next(input_generator)
1362
+ except StopIteration:  # the streamed user input is exhausted
1363
+ this_peer_finished = True
1364
+ break
1365
+ new_input_tokens = new_input_tokens[None,None,:]
1366
+ else:
1367
+ new_input_tokens = next_tokens
1368
+ new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2)
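+ # Builds the next-step input with 2*num_codebooks channels: user-stream tokens (pulled from input_generator on audio steps) in the first half, the assistant's own output in the second half.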
1369
+
1370
+ input_ids = torch.cat([input_ids, new_input_tokens], dim=1)
1371
+ if streamer is not None:
1372
+ streamer.put(next_tokens.cpu())
1373
+ model_kwargs = self._update_model_kwargs_for_generation(
1374
+ outputs, model_kwargs
1375
+ )
1376
+
1377
+ # # if eos_token was found in one sentence, set sentence to finished
1378
+ # if eos_token_id_tensor is not None:
1379
+ # unfinished_sequences = unfinished_sequences.mul(
1380
+ # next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
1381
+ # )
1382
+
1383
+ # # stop when each sentence is finished
1384
+ # if unfinished_sequences.max() == 0:
1385
+ # this_peer_finished = True
1386
+
1387
+ # stop if we exceed the maximum length
1388
+ if input_ids.shape[1] >= max_length:
1389
+ this_peer_finished = True
1390
+
1391
+ if this_peer_finished:
1392
+ break
1393
+
1394
+ if streamer is not None:
1395
+ streamer.end()
1396
+
1397
+ return input_ids
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ transformers
5
+ soundfile
6
+ librosa
7
+ jsonlines
8
+ gradio
9
+ pyannote.audio
spkr.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torchaudio
3
+ from torchaudio.functional import resample
4
+
5
+ from pyannote.audio import Model
6
+ from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
7
+
8
+
9
+ class SpeakerEmbedding:
10
+ def __init__(self, model_path="pyannote/wespeaker-voxceleb-resnet34-LM", device="cuda"):
11
+ model = Model.from_pretrained(model_path).eval()
12
+
13
+ self.device = torch.device(device)
14
+ self.sample_rate = 16000
15
+ self.model = model.to(self.device)
16
+
17
+ @torch.no_grad()
18
+ def __call__(self, wav, sr):
19
+ wav = torch.as_tensor(wav, device=self.device)  # as_tensor avoids the copy warning when wav is already a tensor
20
+ if sr != self.sample_rate:
21
+ wav = resample(wav, sr, self.sample_rate)
22
+ sr = self.sample_rate
23
+
24
+ assert len(wav.shape) <= 3
25
+ is_batch = False
26
+ if len(wav.shape) == 3:
27
+ is_batch = True
28
+ elif len(wav.shape) == 2:
29
+ wav = wav[None, :, :]
30
+ else:
31
+ wav = wav[None, None, :]
32
+
33
+ with torch.inference_mode():
34
+ embeddings = self.model(wav)
35
+
36
+ if is_batch:
37
+ return embeddings
38
+ else:
39
+ return embeddings[0]
40
+
41
+ if __name__ == '__main__':
42
+ import argparse
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument("--wav", type=str, required=True)
45
+ args = parser.parse_args()
46
+
47
+ model = SpeakerEmbedding(device="cuda")
48
+
49
+ wav, sr = torchaudio.load(args.wav)
50
+ print(model(wav, sr))
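+ # Usage sketch (the path is just an example): python spkr.py --wav examples/test.wav
+ # should print a 256-dimensional speaker embedding for the given recording.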
tokenize_func.py ADDED
@@ -0,0 +1,443 @@
1
+ import io
2
+ import copy
3
+ import librosa
4
+ import numpy as np
5
+
6
+
7
+ AUDIO_TOKEN_FORMAT = "<|{}|>"
8
+
9
+ DEFAULT_SYSTEM_START_TOKEN = "<SYSTEM>"
10
+ DEFAULT_SYSTEM_END_TOKEN = "</SYSTEM>"
11
+
12
+ DEFAULT_TTS_REF_START_TOKEN = "<au_tts_ref_start>"
13
+ DEFAULT_TTS_REF_END_TOKEN = "<au_tts_ref_end>"
14
+ DEFAULT_TTS_REF_TOKEN = "<au_tts_ref>"
15
+
16
+ DEFAULT_CHAT_REF_START_TOKEN = "<au_chat_ref_start>"
17
+ DEFAULT_CHAT_REF_END_TOKEN = "<au_chat_ref_end>"
18
+ DEFAULT_CHAT_REF_TOKEN = "<au_chat_ref>"
19
+
20
+ DEFAULT_HUMAN_TOKEN = "<|HUMAN|>"
21
+ DEFAULT_ASSISTANT_TOKEN = "<|VOILA|>"
22
+
23
+ DEFAULT_AUDIO_TOKEN = "<au_token>"
24
+
25
+ # ===================================
26
+ # task special token
27
+ # -----------------------------------
28
+ TASK_ASR_TOKEN = "<asr>"
29
+ TASK_TTS_TOKEN = "<tts>"
30
+ TASK_CHAT_TOKEN = "<chat>"
31
+ TASK_STREAM_CHAT_TOKEN = "<stream_chat>"
32
+
33
+ TASK_ASR_TEXT_OUTPUT = "<asr_text_output>"
34
+ TASK_TTS_AUDIO_OUTPUT = "<tts_audio_output>"
35
+ TASK_CHAT_TEXT_OUTPUT = "<chat_text_output>"
36
+ TASK_CHAT_AUDIO_OUTPUT = "<chat_audio_output>"
37
+
38
+ CHAT_AUDIO_TEXT_SPLIT_TOKEN = "<chat_audio_text_split>"
39
+ # ===================================
40
+
41
+ PREPEND_LEN = 80
42
+ SEG_LEN = 640
43
+ AUDIO_SR = 16000
44
+
45
+ TASK_TYPE_CONF = {
46
+ "chat_asr": TASK_ASR_TOKEN + TASK_ASR_TEXT_OUTPUT,
47
+ "chat_tts": TASK_TTS_TOKEN + TASK_TTS_AUDIO_OUTPUT,
48
+ "chat_tito": TASK_CHAT_TOKEN + TASK_CHAT_TEXT_OUTPUT,
49
+ "chat_tiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
50
+ "chat_aiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
51
+ "chat_atiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
52
+ "chat_aiao_auto": TASK_STREAM_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
53
+ }
54
+
55
+
56
+ def _get_zero_audio_pad(token_num):
57
+ return np.zeros(SEG_LEN*token_num)
58
+
59
+ def _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size):
60
+ ret_audio_tokens = []
61
+ for n in range(num_codebooks):
62
+ audio_token = audio_tokens[n]
63
+ ret_audio_tokens.append(''.join([AUDIO_TOKEN_FORMAT.format(au + n*codebook_size) if isinstance(au, int) else au for au in audio_token]))
64
+ return ret_audio_tokens
65
+
66
+ def _wrapper_audio_tokens_autonomous(audio_tokens, num_codebooks, codebook_size, audio_token_min_id):
67
+ ret_audio_tokens = []
68
+ for n in range(num_codebooks):
69
+ audio_token = audio_tokens[n]
70
+ ret_audio_tokens.append([(au + n*codebook_size + audio_token_min_id) for au in audio_token])
71
+ return ret_audio_tokens
72
+
73
+ # Item format
74
+ # {
75
+ # "instruction": "",
76
+ # "conversations": [
77
+ # {
78
+ # "from": "user" or "assistant",
79
+ # "text": "",
80
+ # "audio": {
81
+ # "array": [],
82
+ # "sr": 16000,
83
+ # "bytes": "",
84
+ # "file": "",
85
+ # },
86
+ # }
87
+ # ],
88
+ # }
89
+ def _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg):
90
+ task_type = dataset_cfg["task_type"]
91
+ num_codebooks = dataset_cfg["num_codebooks"]
92
+ codebook_size = dataset_cfg["codebook_size"]
93
+
94
+ task_token = TASK_TYPE_CONF[task_type]
95
+
96
+ # Construct system message
97
+ system = item["instruction"]
98
+ if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]:
99
+ system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system
100
+ elif task_type == "chat_tts":
101
+ system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system
102
+ else:
103
+ print (f"task type {task_type} do not use ref.")
104
+ system = task_token + system
105
+ system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
106
+
107
+ # Get ids for system
108
+ system_ids = tokenizer.encode(system, add_special_tokens=False)
109
+
110
+ # Copy into num_codebooks input ids
111
+ input_ids_list = []
112
+ for _ in range(num_codebooks):
113
+ input_ids_list.append(copy.deepcopy(system_ids))
114
+
115
+ # Assemble conversations
116
+ for i, turn in enumerate(item["conversations"]):
117
+ if turn['from'] == 'assistant':
118
+ # task with audio token as input, prepare audio token
119
+ if task_type in ["chat_aiao", "chat_tts"]:
120
+ if "audio" not in turn:
121
+ content = DEFAULT_ASSISTANT_TOKEN
122
+ content_ids = tokenizer.encode(content, add_special_tokens=False)
123
+ for n in range(num_codebooks):
124
+ input_ids_list[n] += copy.deepcopy(content_ids)
125
+ else:
126
+ # Load audio
127
+ if 'array' in turn['audio']:
128
+ assert "sr" in turn["audio"]
129
+ if len(turn["audio"]['array'].shape) > 1:
130
+ assert turn["audio"]['array'].shape[0] <= 2
131
+ turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
132
+ audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
133
+ elif "bytes" in turn['audio']:
134
+ audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
135
+ elif "file" in turn['audio']:
136
+ audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
137
+ else:
138
+ raise Exception(f"No audio input for task {task_type}")
139
+
140
+ # get audio token
141
+ audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
142
+ audio_tokens = audio_tokens.cpu().numpy().tolist()
143
+ audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size)
144
+
145
+ for n in range(num_codebooks):
146
+ content = DEFAULT_ASSISTANT_TOKEN + audio_tokens[n] + tokenizer.eos_token
147
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
148
+ max_length=tokenizer.model_max_length)
149
+ input_ids_list[n] += content_ids
150
+
151
+ elif task_type in ["chat_tito", "chat_asr"]:
152
+ if "text" not in turn:
153
+ content = DEFAULT_ASSISTANT_TOKEN
154
+ content_ids = tokenizer.encode(content, add_special_tokens=False)
155
+ for n in range(num_codebooks):
156
+ input_ids_list[n] += copy.deepcopy(content_ids)
157
+ else:
158
+ text = turn['text'].strip()
159
+ content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token
160
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
161
+ max_length=tokenizer.model_max_length)
162
+ for n in range(num_codebooks):
163
+ input_ids_list[n] += copy.deepcopy(content_ids)
164
+ else:
165
+ raise ValueError (f"[Error] Invalid data type of {task_type}.")
166
+ else:
167
+ # task with audio token as input, prepare audio token
168
+ if task_type in ["chat_aiao", "chat_asr"]:
169
+ # Load audio
170
+ assert "audio" in turn
171
+ if 'array' in turn['audio']:
172
+ assert "sr" in turn["audio"]
173
+ if len(turn["audio"]['array'].shape) > 1:
174
+ assert turn["audio"]['array'].shape[0] <= 2
175
+ turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
176
+ audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
177
+ elif "bytes" in turn['audio']:
178
+ audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
179
+ elif "file" in turn['audio']:
180
+ audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
181
+ else:
182
+ raise Exception(f"No audio input for task {task_type}")
183
+
184
+ # get audio token
185
+ audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
186
+ audio_tokens = audio_tokens.cpu().numpy().tolist()
187
+ audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size)
188
+
189
+ for n in range(num_codebooks):
190
+ content = DEFAULT_HUMAN_TOKEN + audio_tokens[n]
191
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
192
+ max_length=tokenizer.model_max_length)
193
+ input_ids_list[n] += copy.deepcopy(content_ids)
194
+ elif task_type in ["chat_tito", "chat_tts"]:
195
+ text = turn['text'].strip()
196
+ content = DEFAULT_HUMAN_TOKEN + text
197
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
198
+ max_length=tokenizer.model_max_length)
199
+ for n in range(num_codebooks):
200
+ input_ids_list[n] += copy.deepcopy(content_ids)
201
+ else:
202
+ raise ValueError (f"[Error] Invalid data type of {task_type}.")
203
+
204
+ for n in range(num_codebooks):
205
+ input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length]
206
+
207
+ return input_ids_list, None, None, None
208
+
209
+ def _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg):
210
+ task_type = dataset_cfg["task_type"]
211
+ num_codebooks = dataset_cfg["num_codebooks"]
212
+ codebook_size = dataset_cfg["codebook_size"]
213
+ assert task_type == "chat_aiao_auto", f"only support chat_aiao_auto, {task_type} is invalid"
214
+
215
+ DEFAULT_HUMAN_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_HUMAN_TOKEN)
216
+ assert isinstance(DEFAULT_HUMAN_TOKEN_ID, int), "DEFAULT_HUMAN_TOKEN_ID should be an integer"
217
+ AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))
218
+ assert isinstance(AUDIO_MIN_TOKEN_ID, int), "AUDIO_MIN_TOKEN_ID should be an integer"
219
+
220
+ task_token = TASK_TYPE_CONF[task_type]
221
+
222
+ # Construct system message
223
+ system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN
224
+ system = task_token + system
225
+ system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
226
+
227
+ # Get ids for system
228
+ system_ids_list = [[], []]
229
+ system_ids = tokenizer.encode(system, add_special_tokens=False)
230
+
231
+ # Insert instruction tokens into system prompt tokens
232
+ instruction = item["instruction"]
233
+ if instruction != "":
234
+ instruction_ids = tokenizer.encode(instruction, add_special_tokens=False)
235
+ else:
236
+ instruction_ids = []
237
+
238
+ system_ids_list[0] = system_ids[:-1] + instruction_ids + system_ids[-1:]
239
+ system_ids_list[1] = system_ids[:-1] + instruction_ids + system_ids[-1:]
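+ # The instruction tokens are spliced in just before the closing system marker (assuming </SYSTEM> encodes to a single id).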
240
+
241
+ # Copy into num_codebooks input ids
242
+ channel1_input_ids_list = [[] for _ in range(num_codebooks)]
243
+ channel2_input_ids_list = [[] for _ in range(num_codebooks)]
244
+ for n in range(num_codebooks):
245
+ channel1_input_ids_list[n] += copy.deepcopy(system_ids_list[0]) + [DEFAULT_HUMAN_TOKEN_ID]
246
+ channel2_input_ids_list[n] += copy.deepcopy(system_ids_list[1]) + [DEFAULT_HUMAN_TOKEN_ID]
247
+
248
+ # prepare audio token to simulate streaming input
249
+ audio_meta = item['conversations'][0]['audio']
250
+ if 'array' in audio_meta:
251
+ assert "sr" in audio_meta
252
+ if len(audio_meta['array'].shape) > 1:
253
+ assert audio_meta['array'].shape[0] <= 2
254
+ audio_meta['array'] = librosa.to_mono(audio_meta['array'])
255
+ audio = librosa.resample(audio_meta['array'], orig_sr=audio_meta["sr"], target_sr=AUDIO_SR)
256
+ elif "bytes" in audio_meta:
257
+ audio, _ = librosa.load(io.BytesIO(audio_meta['bytes']), sr=AUDIO_SR)
258
+ elif "file" in audio_meta:
259
+ audio, _ = librosa.load(audio_meta['file'], sr=AUDIO_SR)
260
+ else:
261
+ raise Exception(f"No audio input for task {task_type}")
262
+
263
+ # get audio token
264
+ streaming_user_input_audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
265
+ streaming_user_input_audio_tokens = streaming_user_input_audio_tokens.cpu().numpy().tolist()
266
+ streaming_user_input_audio_tokens = _wrapper_audio_tokens_autonomous(streaming_user_input_audio_tokens, num_codebooks, codebook_size, AUDIO_MIN_TOKEN_ID)
267
+
268
+ return [channel1_input_ids_list, channel2_input_ids_list], None, None, streaming_user_input_audio_tokens
269
+
270
+ def _alpha_audio_input_format(item, tokenizer, dataset_cfg):
271
+ task_type = dataset_cfg["task_type"]
272
+ num_codebooks = dataset_cfg["num_codebooks"]
273
+ codebook_size = dataset_cfg["codebook_size"]
274
+
275
+ task_token = TASK_TYPE_CONF[task_type]
276
+
277
+ # Construct system message
278
+ system = item["instruction"]
279
+ if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]:
280
+ system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system
281
+ elif task_type == "chat_tts":
282
+ system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system
283
+ else:
284
+ print (f"task type {task_type} do not use ref.")
285
+ system = task_token + system
286
+ system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
287
+
288
+ # Get ids for system
289
+ system_ids = tokenizer.encode(system, add_special_tokens=False)
290
+
291
+ # Copy into num_codebooks input ids
292
+ input_ids_list = []
293
+ for _ in range(num_codebooks):
294
+ input_ids_list.append(copy.deepcopy(system_ids))
295
+
296
+ # Construct audio data and mask
297
+ audio_data = [np.array([0]*PREPEND_LEN)]
298
+ audio_data.append(_get_zero_audio_pad(len(system_ids)))
299
+ audio_data_mask = [0] * len(system_ids)
300
+
301
+ # Assemble conversations
302
+ for i, turn in enumerate(item["conversations"]):
303
+ if turn['from'] == 'assistant':
304
+ # task with audio token as input, prepare audio token
305
+ if task_type in ["chat_aiao"]:
306
+ if "audio" not in turn:
307
+ content = DEFAULT_ASSISTANT_TOKEN
308
+ content_ids = tokenizer.encode(content, add_special_tokens=False)
309
+ for n in range(num_codebooks):
310
+ input_ids_list[n] += copy.deepcopy(content_ids)
311
+ # preprocess audio_data & audio_data_mask
312
+ audio_data.append(_get_zero_audio_pad(len(content_ids)))
313
+ audio_data_mask += [0] * len(content_ids)
314
+ else:
315
+ # Load audio
316
+ if 'array' in turn['audio']:
317
+ assert "sr" in turn["audio"]
318
+ if len(turn["audio"]['array'].shape) > 1:
319
+ assert turn["audio"]['array'].shape[0] <= 2
320
+ turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
321
+ audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
322
+ elif "bytes" in turn['audio']:
323
+ audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
324
+ elif "file" in turn['audio']:
325
+ audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
326
+ else:
327
+ raise Exception(f"No audio input for task {task_type}")
328
+
329
+ # get audio token
330
+ audio_token_num = int(len(audio) / SEG_LEN)
331
+ audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num
332
+ audio_token = ''.join(audio_token)
333
+ audio = audio[:SEG_LEN*audio_token_num] # trim audio
334
+
335
+ content = DEFAULT_ASSISTANT_TOKEN + audio_token + tokenizer.eos_token
336
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
337
+ max_length=tokenizer.model_max_length)
338
+ for n in range(num_codebooks):
339
+ input_ids_list[n] += copy.deepcopy(content_ids)
340
+
341
+ audio_data.append(_get_zero_audio_pad(1))
342
+ audio_data_mask += [0]
343
+ audio_data.append(audio)
344
+ audio_data_mask += [1] * audio_token_num
345
+ audio_data.append(_get_zero_audio_pad(1))
346
+ audio_data_mask += [0]
347
+ elif task_type in ["chat_tito"]:
348
+ if "text" not in turn:
349
+ content = DEFAULT_ASSISTANT_TOKEN
350
+ content_ids = tokenizer.encode(content, add_special_tokens=False)
351
+ for n in range(num_codebooks):
352
+ input_ids_list[n] += copy.deepcopy(content_ids)
353
+ # preprocess audio_data & audio_data_mask
354
+ audio_data.append(_get_zero_audio_pad(len(content_ids)))
355
+ audio_data_mask += [0] * len(content_ids)
356
+ else:
357
+ text = turn['text'].strip()
358
+ content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token
359
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
360
+ max_length=tokenizer.model_max_length)
361
+ for n in range(num_codebooks):
362
+ input_ids_list[n] += copy.deepcopy(content_ids)
363
+ audio_data.append(_get_zero_audio_pad(len(content_ids)))
364
+ audio_data_mask += [0] * len(content_ids)
365
+ else:
366
+ raise ValueError (f"[Error] Invalid data type of {task_type}.")
367
+ else:
368
+ # task with audio token as input, prepare audio token
369
+ if task_type in ["chat_aiao"]:
370
+ # Load audio
371
+ assert "audio" in turn
372
+ if 'array' in turn['audio']:
373
+ assert "sr" in turn["audio"]
374
+ if len(turn["audio"]['array'].shape) > 1:
375
+ assert turn["audio"]['array'].shape[0] <= 2
376
+ turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
377
+ audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
378
+ elif "bytes" in turn['audio']:
379
+ audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
380
+ elif "file" in turn['audio']:
381
+ audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
382
+ else:
383
+ raise Exception(f"No audio input for task {task_type}")
384
+
385
+ # get audio token
386
+ audio_token_num = int(len(audio) / SEG_LEN)
387
+ audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num
388
+ audio_token = ''.join(audio_token)
389
+ audio = audio[:SEG_LEN*audio_token_num] # trim audio
390
+
391
+ content = DEFAULT_HUMAN_TOKEN + audio_token
392
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
393
+ max_length=tokenizer.model_max_length)
394
+ for n in range(num_codebooks):
395
+ input_ids_list[n] += copy.deepcopy(content_ids)
396
+
397
+ audio_data.append(_get_zero_audio_pad(1))
398
+ audio_data_mask += [0]
399
+ audio_data.append(audio)
400
+ audio_data_mask += [1] * audio_token_num
401
+ elif task_type in ["chat_tito"]:
402
+ text = turn['text'].strip()
403
+ content = DEFAULT_HUMAN_TOKEN + text
404
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
405
+ max_length=tokenizer.model_max_length)
406
+ for n in range(num_codebooks):
407
+ input_ids_list[n] += copy.deepcopy(content_ids)
408
+ audio_data.append(_get_zero_audio_pad(len(content_ids)))
409
+ audio_data_mask += [0] * len(content_ids)
410
+ else:
411
+ raise ValueError (f"[Error] Invalid data type of {task_type}.")
412
+
413
+ for n in range(num_codebooks):
414
+ input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length]
415
+ audio_data_mask = audio_data_mask[:tokenizer.model_max_length]
416
+ audio_data = np.concatenate(audio_data)
417
+ audio_data = audio_data[:PREPEND_LEN + tokenizer.model_max_length*SEG_LEN]
418
+
419
+ return input_ids_list, audio_data, audio_data_mask, None
420
+
421
+ # Item format
422
+ # {
423
+ # "instruction": "",
424
+ # "conversations": [
425
+ # {
426
+ # "from": "user" or "assistant",
427
+ # "text": "",
428
+ # "audio": {
429
+ # "array": [],
430
+ # "sr": 16000,
431
+ # "bytes": "",
432
+ # "file": "",
433
+ # },
434
+ # }
435
+ # ],
436
+ # }
437
+ def voila_input_format(item, tokenizer, tokenizer_voila, dataset_cfg):
438
+ if dataset_cfg["input_type"] == "audio":
439
+ return _alpha_audio_input_format(item, tokenizer, dataset_cfg)
440
+ elif dataset_cfg["input_type"] == "autonomous":
441
+ return _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg)
442
+ else:
443
+ return _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg)
voila_tokenizer.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torchaudio
3
+ from torchaudio.functional import resample
4
+
5
+ from transformers import AutoProcessor, EncodecModel
6
+
7
+
8
+ ALL_BANDWIDTHS = [1.1]
9
+
10
+ class VoilaTokenizer:
11
+ def __init__(
12
+ self,
13
+ model_path="maitrix-org/Voila-Tokenizer",
14
+ bandwidth_id=0,
15
+ device="cpu",
16
+ ):
17
+ self.device = torch.device(device)
18
+ self.bandwidth = ALL_BANDWIDTHS[bandwidth_id]
19
+ self.bandwidth_id = torch.tensor([bandwidth_id], device=device)
20
+
21
+ self.processor = AutoProcessor.from_pretrained(model_path)
22
+ self.model = EncodecModel.from_pretrained(model_path).to(device)
23
+
24
+ self.sampling_rate = self.processor.sampling_rate
25
+ self.model_version = self.model.config.model_version
26
+
27
+
28
+ @torch.no_grad()
29
+ def encode(self, wav, sr):
30
+ wav = torch.tensor(wav, dtype=torch.float32, device=self.device)
31
+ if sr != self.processor.sampling_rate:
32
+ wav = resample(wav, sr, self.processor.sampling_rate)
33
+ sr = self.processor.sampling_rate
34
+ if len(wav.shape) == 1:
35
+ wav = wav[None, None, :]
36
+ elif len(wav.shape) == 2:
37
+ assert wav.shape[0] == 1
38
+ wav = wav[None, :]
39
+ elif len(wav.shape) == 3:
40
+ assert wav.shape[0] == 1 and wav.shape[1] == 1
41
+
42
+ # inputs = self.processor(raw_audio=wav, sampling_rate=sr, return_tensors="pt")
43
+ encoder_outputs = self.model.encode(wav, bandwidth=self.bandwidth)
44
+ return encoder_outputs.audio_codes[0, 0]
45
+
46
+ @torch.no_grad()
47
+ def decode(self, audio_codes):
48
+ assert len(audio_codes.shape) == 2
49
+ audio_values = self.model.decode(audio_codes[None, None, :, :], [None])[0]
50
+ return audio_values[0, 0]
51
+
52
+ if __name__ == '__main__':
53
+ import argparse
54
+ import soundfile as sf
55
+
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument("--wav", type=str)
58
+ args = parser.parse_args()
59
+
60
+ wav, sr = torchaudio.load(args.wav)
61
+ if len(wav.shape) > 1:
62
+ wav = wav[0]
63
+
64
+ model = VoilaTokenizer(device="cuda")
65
+
66
+ audio_codes = model.encode(wav, sr)
67
+ audio_values = model.decode(audio_codes).cpu().numpy()
68
+
69
+ tps = audio_codes.shape[-1] / (audio_values.shape[-1] / model.processor.sampling_rate)
70
+ print(audio_codes.shape, audio_values.shape, tps)
71
+ sf.write("audio_mt.wav", audio_values, model.processor.sampling_rate)
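+ # Usage sketch (the path is just an example): python voila_tokenizer.py --wav examples/test.wav
+ # encodes the waveform to discrete codes, decodes them back, and writes the reconstruction to audio_mt.wav.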