okezieowen committed on
Commit 9a4080b · verified · 1 Parent(s): 4a46f69

Update handler.py


Reverting to the non-streaming vLLM implementation.
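For reference, a minimal sketch of how a client might call the reverted, non-streaming handler once it is deployed as an Inference Endpoint (the endpoint URL and token are placeholders; the request and response fields follow the handler code in this diff):

import base64
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <hf_token>", "Content-Type": "application/json"}

# Standard text-to-speech request; the handler now returns one JSON payload
# (base64-encoded WAV plus sample rate) instead of streaming audio chunks.
payload = {
    "inputs": "Good morning, how are you today?",
    "parameters": {"voice": "Eniola", "temperature": 0.6, "top_p": 0.95},
}
response = requests.post(API_URL, headers=HEADERS, json=payload).json()

with open("output.wav", "wb") as f:
    f.write(base64.b64decode(response["audio_b64"]))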

Files changed (1)
  1. handler.py +271 -198
handler.py CHANGED
@@ -1,21 +1,16 @@
- import asyncio
- import torch
  import os
- import threading
- import queue
  import numpy as np
- import time
  import base64
  import io
  import wave
- import librosa
- import soundfile as sf
- import random

- from vllm import LLM, LLMEngine, EngineArgs, AsyncLLMEngine, AsyncEngineArgs, SamplingParams
- from vllm.sampling_params import RequestOutputKind
- from transformers import AutoTokenizer, pipeline
  from snac import SNAC

  class EndpointHandler:
      def __init__(self, path=""):
@@ -31,15 +26,9 @@ class EndpointHandler:
          self.END_OF_AI = 128262
          self.AUDIO_TOKENS_START = 128266

-         self.engine_args = AsyncEngineArgs(
-             model = "okezieowen/hypaai_orpheus",
-             dtype = torch.bfloat16,
-             max_model_len = 4096,
-             gpu_memory_utilization = 0.5,
-         )
-         self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)
-
-         self.tokenizer = AutoTokenizer.from_pretrained("okezieowen/hypaai_orpheus")

          # Move to devices
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -52,201 +41,180 @@ class EndpointHandler:
          except Exception as e:
              raise RuntimeError(f"Failed to load SNAC model: {e}")

-     def _format_prompt(self, prompt, voice="Eniola"):
-         modified_prompt = f"{voice}: {prompt}"
-         prompt_tokens = self.tokenizer(modified_prompt, return_tensors="pt").input_ids
-         all_input_ids = torch.cat([
              torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
              torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
-             prompt_tokens,
              torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
-             torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64),
              torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
              torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
-         ], dim = 1)
-         prompt_string = self.tokenizer.decode(all_input_ids[0])
-         return prompt_string
-

-     def _convert_codes_to_audio_array(self, code_list):
          """
-         Decode audio from SNAC codes
          """
-         if len(code_list) < 7:
-             raise ValueError(f"Audio codes must have at least 7 tokens.")
-
-         layer_1 = [] # Coarsest layer
-         layer_2 = [] # Intermediate layer
-         layer_3 = [] # Finest layer

-         num_groups = len(code_list) // 7

-         # Trim length to a multiple of 7
-         n_codes = num_groups * 7
-         code_list = code_list[:n_codes]

-         for i in range(num_groups):
-             idx = 7 * i
-             layer_1.append(code_list[idx + 0])
-             layer_2.append(code_list[idx + 1])
-             layer_3.append(code_list[idx + 2])
-             layer_3.append(code_list[idx + 3])
-             layer_2.append(code_list[idx + 4])
-             layer_3.append(code_list[idx + 5])
-             layer_3.append(code_list[idx + 6])

-         codes = [
-             torch.tensor(layer_1).unsqueeze(0).to(self.device),
-             torch.tensor(layer_2).unsqueeze(0).to(self.device),
-             torch.tensor(layer_3).unsqueeze(0).to(self.device),
-         ]

-         for i, code in enumerate(codes):
-             if torch.any(code < 0) or torch.any(code >= 4096):
-                 raise ValueError(f"Invalid code index in layer {i}: found value out of range [0, 4095]")

-         # Decode audio
          with torch.inference_mode():
-             audio_hat = self.snac_model.decode(codes)
-         return audio_hat
-
-     def _turn_token_into_id(self, token_string, index):
-         # Strip whitespace
-         token_string = token_string.strip()
-
-         # Find the last token in the string
-         last_token_start = token_string.rfind("<custom_token_")
-
-         if last_token_start == -1:
-             print("No token found in the string")
-             return None

-         # Extract the last token
-         last_token = token_string[last_token_start:]

-         # Process the last token
-         if last_token.startswith("<custom_token_") and last_token.endswith(">"):
-             try:
-                 number_str = last_token[14:-1]
-                 return int(number_str) - 10 - ((index % 7) * 4096)
-             except ValueError:
-                 return None
-         else:
-             return None
-
-     async def _generate_token(self, prompt_string, sampling_params, request_id):
-         async for ro in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
-             token = ro.outputs[0].text
-             yield token
-
-     async def _generate_token_buffer(self, token_gen, audio_frame_width, audio_frame_overlap):
-         last_emit = 0
-         buffer = []
-         count = 0
-         hop_length = (audio_frame_width - audio_frame_overlap) * 7
-         token_frame_width = audio_frame_width * 7
-
-         async for token in token_gen:
-             token_id = self._turn_token_into_id(token, count)
-             if token_id is None:
-                 continue
-
-             # Accept only token IDs in [0, 4095]
-             if 0 <= token_id < 4096:
-                 buffer.append(token_id)
-                 count += 1
-             else:
-                 continue
-
-             while count - last_emit >= hop_length and count >= token_frame_width:
-                 buffer_to_process = buffer[-token_frame_width:]
-                 yield buffer_to_process
-                 last_emit += hop_length
-
-         # After the vLLM engine finishes, yield any remaining tokens.
-         if count > last_emit:
-             # Pad the final buffer to be a multiple of 7 before yielding.
-             remaining_len = len(buffer) % token_frame_width
-             if remaining_len != 0:
-                 padding_needed = token_frame_width - remaining_len
-                 buffer.extend([0] * padding_needed)
-
-             # Process and yield the final, potentially incomplete but padded buffer.
-             buffer_to_process = buffer[-token_frame_width:]
-             yield buffer_to_process
-
-     async def _decode_tokens(self, token_buffer_generator):
-         async for audio_token_buffer in token_buffer_generator:
-             audio_samples = self._convert_codes_to_audio_array(audio_token_buffer)
-             yield audio_samples
-
-     async def _convert_audio_tensor_to_audio_numpy(self, audio_tensor_generator):
-         async for audio_tensor in audio_tensor_generator:
-             audio_numpy = audio_tensor.detach().squeeze().cpu().numpy()
-
-             # # Convert float32 array to int16 for WAV format
-             # audio_int16 = (audio_numpy * 32767).astype(np.int16)
-
-             # # Write to WAV in memory (float32 or int16 depending on your preference)
-             # buffer = io.BytesIO()
-             # sf.write(buffer, audio_numpy, samplerate=24000, format='WAV', subtype='PCM_16') # or PCM_32
-             # buffer.seek(0)
-
-             # # Encode WAV bytes as base64
-             # audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
-
-             yield audio_numpy
-
-     async def _generate_speech(self, prompt_string, sampling_params, request_id, audio_frame_width, audio_frame_overlap):
-
-         # Step 1: Generate tokens from prompt
-         token_gen = self._generate_token(
-             prompt_string=prompt_string,
-             sampling_params=sampling_params,
-             request_id=request_id,
-         )
-
-         # Step 2: Buffer tokens
-         token_buffer_gen = self._generate_token_buffer(
-             token_gen,
-             audio_frame_width,
-             audio_frame_overlap,
-         )

-         # Step 3: Decode to audio tensors
-         audio_tensor_gen = self._decode_tokens(token_buffer_gen)

-         # Step 4: Move tensors to CPU and convert to NumPy
-         audio_numpy_gen = self._convert_audio_tensor_to_audio_numpy(audio_tensor_gen)

-         # Step 5: Yield NumPy arrays
-         async for audio_numpy in audio_numpy_gen:
-             yield audio_numpy

-     def preprocess(self, data):

-         prompt = data.get("input", "You did not provide any prompt.")
-         voice = data.get("voice", "Eniola")
-
          parameters = data.get("parameters", {})

          temperature = float(parameters.get("temperature", 0.6))
          top_p = float(parameters.get("top_p", 0.95))
          max_new_tokens = int(parameters.get("max_new_tokens", 1200))
          repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
-         audio_frame_width = int(parameters.get("audio_frame_width", 10))
-         audio_frame_overlap = int(parameters.get("audio_frame_overlap", 5))

-         formatted_prompt = self._format_prompt(prompt, voice)

          return {
-             "formatted_prompt": formatted_prompt,
              "temperature": temperature,
              "top_p": top_p,
              "max_new_tokens": max_new_tokens,
              "repetition_penalty": repetition_penalty,
-             "audio_frame_width": audio_frame_width,
-             "audio_frame_overlap": audio_frame_overlap,
          }

      def inference(self, inputs):
@@ -254,34 +222,139 @@ class EndpointHandler:
          Run model inference on the preprocessed inputs
          """
          # Extract parameters
-         formatted_prompt = inputs["formatted_prompt"]

-         self.sampling_params = SamplingParams(
              temperature = inputs["temperature"],
              top_p = inputs["top_p"],
-             max_tokens = inputs["max_new_tokens"],
              repetition_penalty = inputs["repetition_penalty"],
              stop_token_ids = [self.END_OF_SPEECH],
          )

-         speech_gen = self._generate_speech(
-             prompt_string = formatted_prompt,
-             sampling_params = self.sampling_params,
-             request_id = str(random.randint(1000, 9999)),
-             audio_frame_width = inputs["audio_frame_width"],
-             audio_frame_overlap = inputs["audio_frame_overlap"],
-         )
-         return speech_gen

-     # Main entry point for the handler
-     async def __call__(self, data):
          try:
-             preprocessed_inputs = self.preprocess(data)
-             speech_gen = self.inference(preprocessed_inputs)
-             async for response in speech_gen:
-                 yield response

          # Catch that error, baby
          except Exception as e:
              traceback.print_exc()
-             yield {"error": str(e)}
 
  import os
+ import torch
  import numpy as np
+ import librosa
+ import soundfile as sf
+ import traceback
  import base64
  import io
  import wave

+ from transformers import AutoModelForCausalLM, AutoTokenizer
  from snac import SNAC
+ from vllm import LLM, SamplingParams

  class EndpointHandler:
      def __init__(self, path=""):

          self.END_OF_AI = 128262
          self.AUDIO_TOKENS_START = 128266

+         # Load the models and tokenizer
+         self.model = LLM(path, max_model_len = 4096, gpu_memory_utilization = 0.3)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)

          # Move to devices
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
          except Exception as e:
              raise RuntimeError(f"Failed to load SNAC model: {e}")

+     # Set up functions to format and encode text/audio
+     def encode_text(self, text):
+         return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
+
+     def encode_audio(self, base64_audio_str):
+         audio_bytes = base64.b64decode(base64_audio_str)
+         audio_buffer = io.BytesIO(audio_bytes)
+         waveform, sr = sf.read(audio_buffer, dtype='float32')
+
+         if waveform.ndim > 1:
+             waveform = np.mean(waveform, axis=1)
+         if sr != 24000:
+             waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000)
+         return self.tokenize_audio(waveform)
+
+     def format_text_block(self, text_ids):
+         return [
              torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
              torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
+             text_ids,
              torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
+             torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64)
+         ]
+
+     def format_audio_block(self, audio_codes):
+         return [
              torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
              torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
+             torch.tensor([audio_codes], dtype=torch.int64),
+             torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64),
+             torch.tensor([[self.END_OF_AI]], dtype=torch.int64)
+         ]

+     def enroll_user(self, enrollment_pairs):
          """
+         Parameters:
+         - enrollment_pairs: List of tuples (text, audio_data), where audio_data is
+           base64-encoded audio data
+         Returns:
+         - cloning_features (str): serialized enrollment data
          """
+         enrollment_data = []
+
+         for text, base64_audio in enrollment_pairs:
+             text_ids = self.encode_text(text).cpu()
+             audio_codes = self.encode_audio(base64_audio)
+             enrollment_data.append({
+                 "text_ids": text_ids,
+                 "audio_codes": audio_codes
+             })
+
+         # Serialize enrollment data
+         buffer = io.BytesIO()
+         torch.save(enrollment_data, buffer)
+         buffer.seek(0)
+
+         # Encode as base64 string and assign to attribute
+         cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
+         return cloning_features
+
+     def prepare_audio_tokens_for_decoder(self, audio_codes_list):
+         """
+         Given a list containing sequences of generated audio codes, do the following:
+         1. Trim length to a multiple of 7 (SNAC decoder requires 7 tokens per audio frame)
+         2. Adjust token values to SNAC decoder's expected range
+         """
+         modified_audio_codes_list = []
+         for audio_codes in audio_codes_list:

+             # Trim length to a multiple of 7
+             length = (audio_codes.size(0) // 7) * 7
+             trimmed = audio_codes[:length]

+             # Adjust token values to SNAC decoder's expected range
+             audio_codes = trimmed - self.AUDIO_TOKENS_START

+             # Add modified audio codes to list
+             modified_audio_codes_list.append(audio_codes)

+         return modified_audio_codes_list

+     # Convert audio sample to codes and reconstruct
+     def tokenize_audio(self, waveform):
+         waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)

          with torch.inference_mode():
+             codes = self.snac_model.encode(waveform)

+         all_codes = []
+         for i in range(codes[0].shape[1]):

+             all_codes.append(codes[0][0][(1 * i) + 0].item() + self.AUDIO_TOKENS_START + (0 * 4096))
+             all_codes.append(codes[1][0][(2 * i) + 0].item() + self.AUDIO_TOKENS_START + (1 * 4096))
+             all_codes.append(codes[2][0][(4 * i) + 0].item() + self.AUDIO_TOKENS_START + (2 * 4096))
+             all_codes.append(codes[2][0][(4 * i) + 1].item() + self.AUDIO_TOKENS_START + (3 * 4096))
+             all_codes.append(codes[1][0][(2 * i) + 1].item() + self.AUDIO_TOKENS_START + (4 * 4096))
+             all_codes.append(codes[2][0][(4 * i) + 2].item() + self.AUDIO_TOKENS_START + (5 * 4096))
+             all_codes.append(codes[2][0][(4 * i) + 3].item() + self.AUDIO_TOKENS_START + (6 * 4096))

+         return all_codes

+     def preprocess(self, data):

+         # Preprocess input data before inference

+         self.voice_cloning = data.get("clone", False)

+         # Extract parameters from request
+         target_text = data["inputs"]
          parameters = data.get("parameters", {})
+         cloning_features = data.get("cloning_features", None)

          temperature = float(parameters.get("temperature", 0.6))
          top_p = float(parameters.get("top_p", 0.95))
          max_new_tokens = int(parameters.get("max_new_tokens", 1200))
          repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

+         if self.voice_cloning:
+             """Handle voice cloning using cloning features"""
+
+             if not cloning_features:
+                 raise ValueError("No cloning features were provided")
+             else:
+                 # Decode back into tensors
+                 enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
+
+                 # Process pre-tokenized enrollment_data
+                 input_sequence = []
+                 for item in enrollment_data:
+                     text_ids = item["text_ids"]
+                     audio_codes = item["audio_codes"]
+                     input_sequence.extend(self.format_text_block(text_ids))
+                     input_sequence.extend(self.format_audio_block(audio_codes))
+
+                 # Append target text whose audio we want
+                 target_text_ids = self.encode_text(target_text)
+                 input_sequence.extend(self.format_text_block(target_text_ids))
+
+                 # Start of target audio - audio codes to be completed by model
+                 input_sequence.extend([
+                     torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
+                     torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64)
+                 ])
+
+                 # Final input tensor
+                 input_ids = torch.cat(input_sequence, dim=1)
+
+                 # Heuristic to determine max_new_tokens based on empirical relationship
+                 # between the length of the prompt ids and the length of the generated ids
+                 prompt_ids = self.encode_text(target_text)
+                 max_new_tokens = int(prompt_ids.size()[1] * 20 + 200)
+
+                 input_ids = input_ids.to(self.device)
+
+         else:
+             # Handle standard text-to-speech
+
+             # Extract parameters from request
+             voice = parameters.get("voice", "Eniola")
+             prompt = f"{voice}: {target_text}"
+             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+
+             # Add special tokens
+             input_ids = torch.cat(self.format_text_block(input_ids), dim=1)
+
+             # No need for padding as we're processing a single sequence
+             input_ids = input_ids.to(self.device)

          return {
+             "input_ids": input_ids,
              "temperature": temperature,
              "top_p": top_p,
              "max_new_tokens": max_new_tokens,
              "repetition_penalty": repetition_penalty,
          }

      def inference(self, inputs):
 
          Run model inference on the preprocessed inputs
          """
          # Extract parameters
+         input_ids = inputs["input_ids"]

+         sampling_params = SamplingParams(
              temperature = inputs["temperature"],
              top_p = inputs["top_p"],
+             max_tokens = inputs["max_new_tokens"],
              repetition_penalty = inputs["repetition_penalty"],
              stop_token_ids = [self.END_OF_SPEECH],
          )
+
+         prompt_string = self.tokenizer.decode(input_ids[0])

+         # Forward pass through the model
+         generated_ids = self.model.generate(prompt_string, sampling_params)
+
+         return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
+
+     def __call__(self, data):
+
+         # Main entry point for the handler

          try:
+             enroll_user = data.get("enroll_user", False)
+
+             if enroll_user:
+                 # We extract cloning features for enrollment
+                 enrollment_pairs = data.get("enrollments", [])
+                 cloning_features = self.enroll_user(enrollment_pairs)
+                 return {"cloning_features": cloning_features}
+             else:
+                 # We want to generate speech using preset cloning features
+                 preprocessed_inputs = self.preprocess(data)
+                 model_outputs = self.inference(preprocessed_inputs)
+                 response = self.postprocess(model_outputs)
+                 return response

          # Catch that error, baby
          except Exception as e:
              traceback.print_exc()
+             return {"error": str(e)}
+
+     # Postprocess generated ids
+     def convert_codes_to_waveform(self, code_list):
+         """
+         Reorganize tokens for SNAC decoding
+         """
+         layer_1 = [] # Coarsest layer
+         layer_2 = [] # Intermediate layer
+         layer_3 = [] # Finest layer
+
+         num_groups = len(code_list) // 7
+         for i in range(num_groups):
+             idx = 7 * i
+             layer_1.append(code_list[7 * i + 0] - (0 * 4096))
+             layer_2.append(code_list[7 * i + 1] - (1 * 4096))
+             layer_3.append(code_list[7 * i + 2] - (2 * 4096))
+             layer_3.append(code_list[7 * i + 3] - (3 * 4096))
+             layer_2.append(code_list[7 * i + 4] - (4 * 4096))
+             layer_3.append(code_list[7 * i + 5] - (5 * 4096))
+             layer_3.append(code_list[7 * i + 6] - (6 * 4096))
+
+         codes = [
+             torch.tensor(layer_1).unsqueeze(0).to(self.device),
+             torch.tensor(layer_2).unsqueeze(0).to(self.device),
+             torch.tensor(layer_3).unsqueeze(0).to(self.device)
+         ]
+
+         # Decode audio
+         audio_hat = self.snac_model.decode(codes)
+         return audio_hat
+
+     def postprocess(self, generated_ids):
+
+         if self.voice_cloning:
+             """
+             For cloning applications, use this postprocess function to get generated audio samples
+             """
+             # Modify audio codes to be digestible by SNAC decoder
+             code_lists = self.prepare_audio_tokens_for_decoder(generated_ids)
+
+             # Generate audio from codes
+             temp = self.convert_codes_to_waveform(code_lists[0])
+             audio_sample = temp.detach().squeeze().to("cpu").numpy()
+
+         else:
+             """
+             Process generated tokens into audio
+             """
+             # Find Start of Audio token
+             token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)
+
+             if len(token_indices[1]) > 0:
+                 last_occurrence_idx = token_indices[1][-1].item()
+                 cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+             else:
+                 cropped_tensor = generated_ids
+
+             # Remove End of Audio tokens
+             processed_rows = []
+             for row in cropped_tensor:
+                 masked_row = row[row != self.END_OF_SPEECH]
+                 processed_rows.append(masked_row)
+
+             code_lists = self.prepare_audio_tokens_for_decoder(processed_rows)
+
+             # Generate audio from codes
+             audio_samples = []
+             for code_list in code_lists:
+                 if len(code_list) > 0:
+                     audio = self.convert_codes_to_waveform(code_list)
+                     audio_samples.append(audio)
+                 else:
+                     raise ValueError("Empty code list, no audio to generate")
+
+             if not audio_samples:
+                 return {"error": "No audio samples generated"}
+
+             # Return first (and only) audio sample
+             audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
+
+         # Convert float32 array to int16 for WAV format
+         audio_int16 = (audio_sample * 32767).astype(np.int16)
+
+         # Write to WAV in memory (float32 or int16 depending on your preference)
+         buffer = io.BytesIO()
+         sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16') # or PCM_32
+         buffer.seek(0)
+
+         # Encode WAV bytes as base64
+         audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
+
+         return {
+             "audio_sample": audio_sample,
+             "audio_b64": audio_b64,
+             "sample_rate": 24000,
+         }
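A usage sketch of the voice-cloning path added above, assuming the same placeholder endpoint and token as in the earlier example: enrollment returns serialized cloning_features, which are then passed back with clone=True for synthesis.

import base64
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <hf_token>", "Content-Type": "application/json"}

# Step 1: enroll reference (text, base64 WAV) pairs and keep the returned features
with open("reference.wav", "rb") as f:
    reference_b64 = base64.b64encode(f.read()).decode("utf-8")
enrollment = requests.post(API_URL, headers=HEADERS, json={
    "enroll_user": True,
    "enrollments": [["This is my reference sentence.", reference_b64]],
}).json()

# Step 2: synthesize new text in the enrolled voice
cloned = requests.post(API_URL, headers=HEADERS, json={
    "inputs": "Now say this sentence in the enrolled voice.",
    "clone": True,
    "cloning_features": enrollment["cloning_features"],
}).json()

with open("cloned.wav", "wb") as f:
    f.write(base64.b64decode(cloned["audio_b64"]))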