import os
import torch
import numpy as np
import librosa
import soundfile as sf
import traceback
import base64
import io
import wave
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
from vllm import LLM, SamplingParams


class EndpointHandler:
    """Inference-endpoint handler for Orpheus text-to-speech with optional voice cloning.

    Text is wrapped in Orpheus prompt delimiters, vLLM generates audio-code tokens,
    and the SNAC 24 kHz codec decodes those codes into a waveform that is returned
    both as a numpy array and as a base64-encoded 16-bit WAV.
    """

    # Delimiter tokens as defined in Orpheus' vocabulary
    START_OF_HUMAN = 128259
    START_OF_TEXT = 128000
    END_OF_TEXT = 128009
    END_OF_HUMAN = 128260
    START_OF_AI = 128261
    START_OF_SPEECH = 128257
    END_OF_SPEECH = 128258
    END_OF_AI = 128262
    # First vocabulary id used for audio codes: code c at frame position k is
    # stored as AUDIO_TOKENS_START + k * CODEBOOK_SIZE + c.
    AUDIO_TOKENS_START = 128266

    # SNAC layout: 7 tokens per audio frame, 4096 codebook entries per position
    CODEBOOK_SIZE = 4096
    FRAME_SIZE = 7
    SAMPLE_RATE = 24000

    def __init__(self, path="", max_model_len=4096, gpu_memory_utilization=0.3):
        """Load the Orpheus LLM (via vLLM), its tokenizer and the SNAC codec.

        Args:
            path: Model directory / repo id used for both vLLM and the tokenizer.
            max_model_len: Maximum context length passed to vLLM.
            gpu_memory_utilization: Fraction of GPU memory vLLM may reserve
                (presumably kept low so the SNAC codec fits on the same GPU).

        Raises:
            RuntimeError: if the SNAC model cannot be loaded.
        """
        self.model = LLM(
            path,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Mode of the most recent preprocess() call; informational only,
        # postprocess() does not read it.
        self.voice_cloning = False

        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
            self.snac_model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load SNAC model: {e}") from e

    # ------------------------------------------------------------------
    # Encoding text / audio into prompt tokens
    # ------------------------------------------------------------------

    @staticmethod
    def _token(token_id):
        """Wrap a single token id as a (1, 1) int64 tensor for torch.cat."""
        return torch.tensor([[token_id]], dtype=torch.int64)

    def encode_text(self, text):
        """Tokenize text without special tokens, returning a (1, n) tensor of ids."""
        return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)

    def encode_audio(self, base64_audio_str):
        """Decode base64 audio, downmix to mono, resample to 24 kHz and tokenize it.

        Returns:
            list[int]: Orpheus audio tokens (see tokenize_audio).
        """
        audio_bytes = base64.b64decode(base64_audio_str)
        waveform, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
        if waveform.ndim > 1:
            # Multi-channel audio: average the channels (axis 1) into mono
            waveform = np.mean(waveform, axis=1)
        if sr != self.SAMPLE_RATE:
            waveform = librosa.resample(waveform, orig_sr=sr, target_sr=self.SAMPLE_RATE)
        return self.tokenize_audio(waveform)

    def format_text_block(self, text_ids):
        """Wrap (1, n) text ids in human-turn delimiters; returns tensors for torch.cat."""
        return [
            self._token(self.START_OF_HUMAN),
            self._token(self.START_OF_TEXT),
            text_ids,
            self._token(self.END_OF_TEXT),
            self._token(self.END_OF_HUMAN),
        ]

    def format_audio_block(self, audio_codes):
        """Wrap a list of audio tokens in AI/speech delimiters; returns tensors for torch.cat."""
        return [
            self._token(self.START_OF_AI),
            self._token(self.START_OF_SPEECH),
            torch.tensor([audio_codes], dtype=torch.int64),
            self._token(self.END_OF_SPEECH),
            self._token(self.END_OF_AI),
        ]

    def tokenize_audio(self, waveform):
        """Encode a mono waveform with SNAC and flatten the codes into Orpheus tokens.

        Args:
            waveform: 1-D numpy array sampled at 24 kHz.

        Returns:
            list[int]: 7 tokens per SNAC frame, ordered
            [coarse, mid0, fine0, fine1, mid1, fine2, fine3], with position k
            offset by AUDIO_TOKENS_START + k * CODEBOOK_SIZE.
        """
        samples = np.asarray(waveform, dtype=np.float32)
        # SNAC input is (batch, channels, samples)
        audio = torch.from_numpy(samples).unsqueeze(0).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            codes = self.snac_model.encode(audio)
            num_frames = codes[0].shape[1]
            # Per frame: 1 coarse, 2 intermediate and 4 fine codes
            coarse = codes[0][0][:num_frames].reshape(num_frames, 1)
            mid = codes[1][0][: 2 * num_frames].reshape(num_frames, 2)
            fine = codes[2][0][: 4 * num_frames].reshape(num_frames, 4)
            frames = torch.cat(
                [coarse, mid[:, :1], fine[:, :2], mid[:, 1:], fine[:, 2:]], dim=1
            ).to(torch.int64)
            offsets = (
                self.AUDIO_TOKENS_START
                + torch.arange(self.FRAME_SIZE, device=frames.device) * self.CODEBOOK_SIZE
            )
            tokens = frames + offsets
        # One host transfer instead of a per-element .item() sync
        return tokens.reshape(-1).tolist()

    # ------------------------------------------------------------------
    # Enrollment (voice cloning features)
    # ------------------------------------------------------------------

    def enroll_user(self, enrollment_pairs):
        """Tokenize enrollment samples and serialize them as cloning features.

        Args:
            enrollment_pairs: Iterable of (text, audio_data) tuples, where
                audio_data is base64-encoded audio.

        Returns:
            str: base64 string of the torch-serialized list of
            {"text_ids": tensor, "audio_codes": list[int]} dicts.
        """
        enrollment_data = [
            {
                "text_ids": self.encode_text(text).cpu(),
                "audio_codes": self.encode_audio(base64_audio),
            }
            for text, base64_audio in enrollment_pairs
        ]
        buffer = io.BytesIO()
        torch.save(enrollment_data, buffer)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @staticmethod
    def _load_enrollment_data(cloning_features):
        """Deserialize the output of enroll_user().

        cloning_features comes from the request, so loading is restricted with
        weights_only=True (tensors and primitive containers only) to prevent
        arbitrary code execution through pickle.
        """
        raw = base64.b64decode(cloning_features)
        return torch.load(io.BytesIO(raw), map_location="cpu", weights_only=True)

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    def _build_cloning_prompt(self, target_text, cloning_features):
        """Build enrollment turns + target text + an open AI speech turn.

        Returns:
            (input_ids, target_text_ids): both (1, n) int64 CPU tensors.

        Raises:
            ValueError: if no cloning features were provided.
        """
        if not cloning_features:
            raise ValueError("No cloning features were provided")

        sequence = []
        for item in self._load_enrollment_data(cloning_features):
            sequence.extend(self.format_text_block(item["text_ids"]))
            sequence.extend(self.format_audio_block(item["audio_codes"]))

        target_text_ids = self.encode_text(target_text)
        sequence.extend(self.format_text_block(target_text_ids))
        # Open the target speech turn so the model continues with audio codes
        sequence.append(self._token(self.START_OF_AI))
        sequence.append(self._token(self.START_OF_SPEECH))
        return torch.cat(sequence, dim=1), target_text_ids

    def _build_tts_prompt(self, prompt):
        """Wrap a "<voice>: <text>" prompt in a single human turn.

        encode_text skips the tokenizer's own special tokens because
        format_text_block already inserts START_OF_TEXT; the tokenizer defaults
        would presumably prepend a second <|begin_of_text|>.
        """
        return torch.cat(self.format_text_block(self.encode_text(prompt)), dim=1)

    def preprocess(self, data):
        """Build prompt token ids and sampling settings for one request.

        Args:
            data: Request dict with "inputs" (target text) and optional "clone",
                "cloning_features" and "parameters" (temperature, top_p,
                max_new_tokens, repetition_penalty, voice).

        Returns:
            dict with "input_ids" ((1, n) tensor on self.device), "temperature",
            "top_p", "max_new_tokens" and "repetition_penalty".

        Raises:
            KeyError: if "inputs" is missing.
            ValueError: if cloning is requested without cloning features.
        """
        voice_cloning = bool(data.get("clone", False))
        self.voice_cloning = voice_cloning

        target_text = data["inputs"]
        parameters = data.get("parameters") or {}
        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

        if voice_cloning:
            input_ids, target_text_ids = self._build_cloning_prompt(
                target_text, data.get("cloning_features")
            )
            # Empirical heuristic relating target-text length to generated length;
            # it overrides any caller-supplied max_new_tokens.
            max_new_tokens = int(target_text_ids.size(1) * 20 + 200)
        else:
            voice = parameters.get("voice", "Eniola")
            input_ids = self._build_tts_prompt(f"{voice}: {target_text}")

        return {
            "input_ids": input_ids.to(self.device),
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
        }

    def inference(self, inputs):
        """Generate audio tokens with vLLM from preprocess() output.

        Returns:
            torch.Tensor: (1, n) int64 tensor of generated token ids.
        """
        sampling_params = SamplingParams(
            temperature=inputs["temperature"],
            top_p=inputs["top_p"],
            max_tokens=inputs["max_new_tokens"],
            repetition_penalty=inputs["repetition_penalty"],
            stop_token_ids=[self.END_OF_SPEECH],
        )
        # The prompt is handed to vLLM as text, so decode the ids back to a string
        prompt_string = self.tokenizer.decode(inputs["input_ids"][0])
        outputs = self.model.generate(prompt_string, sampling_params)
        token_ids = list(outputs[0].outputs[0].token_ids)
        # Explicit dtype keeps an empty generation integer-typed
        return torch.tensor(token_ids, dtype=torch.int64).unsqueeze(0)

    def __call__(self, data):
        """Endpoint entry point: enroll a user or synthesize speech.

        Returns:
            {"cloning_features": str} for enrollment requests, the postprocess()
            dict for synthesis, or {"error": str} if any step raised.
        """
        try:
            if data.get("enroll_user", False):
                cloning_features = self.enroll_user(data.get("enrollments", []))
                return {"cloning_features": cloning_features}

            preprocessed_inputs = self.preprocess(data)
            model_outputs = self.inference(preprocessed_inputs)
            return self.postprocess(model_outputs)
        except Exception as e:
            # Endpoint boundary: log the traceback and report the error to the client
            traceback.print_exc()
            return {"error": str(e)}

    # ------------------------------------------------------------------
    # Decoding generated ids into audio
    # ------------------------------------------------------------------

    def prepare_audio_tokens_for_decoder(self, audio_codes_list):
        """Prepare generated audio-token sequences for SNAC decoding.

        For each 1-D tensor: trim its length down to a multiple of 7 (one SNAC
        frame) and shift the values by AUDIO_TOKENS_START.

        Returns:
            list[torch.Tensor]: one shifted tensor per input sequence.
        """
        prepared = []
        for audio_codes in audio_codes_list:
            usable = (audio_codes.size(0) // self.FRAME_SIZE) * self.FRAME_SIZE
            prepared.append(audio_codes[:usable] - self.AUDIO_TOKENS_START)
        return prepared

    def convert_codes_to_waveform(self, code_list):
        """Split flattened frame codes into SNAC's three layers and decode them.

        Args:
            code_list: 1-D tensor or sequence of codes already shifted by
                AUDIO_TOKENS_START (see prepare_audio_tokens_for_decoder); a
                trailing partial frame is ignored.

        Returns:
            torch.Tensor: the SNAC decoder output.

        Raises:
            ValueError: if there is no complete frame, or a code falls outside its
                position's codebook (such codes cannot be looked up by SNAC, so
                fail early with a clear error).
        """
        codes = torch.as_tensor(code_list, dtype=torch.int64).reshape(-1)
        num_frames = codes.numel() // self.FRAME_SIZE
        if num_frames == 0:
            raise ValueError("Not enough audio codes to decode a single frame")

        frames = codes[: num_frames * self.FRAME_SIZE].reshape(num_frames, self.FRAME_SIZE)
        # Undo the per-position offset k * CODEBOOK_SIZE applied during tokenization
        frames = frames - torch.arange(self.FRAME_SIZE, device=frames.device) * self.CODEBOOK_SIZE
        if bool(((frames < 0) | (frames >= self.CODEBOOK_SIZE)).any()):
            raise ValueError("Generated audio codes are outside the SNAC codebook range")

        # Frame positions per layer: coarse [0], intermediate [1, 4], fine [2, 3, 5, 6]
        layers = [
            frames[:, columns].reshape(1, -1).to(self.device)
            for columns in ([0], [1, 4], [2, 3, 5, 6])
        ]
        with torch.no_grad():
            return self.snac_model.decode(layers)

    def postprocess(self, generated_ids):
        """Turn generated token ids into audio.

        Plain TTS and voice cloning share one decode path. Everything up to and
        including the last START_OF_SPEECH is dropped: plain-TTS prompts end at
        END_OF_HUMAN, so the model is expected to emit that delimiter itself,
        whereas cloning prompts already end with it. END_OF_SPEECH tokens are
        then removed and the first row is decoded with SNAC.

        Args:
            generated_ids: (batch, n) int64 tensor from inference().

        Returns:
            dict with "audio_sample" (numpy waveform), "audio_b64" (base64 16-bit
            PCM WAV) and "sample_rate"; or {"error": ...} if there are no rows.

        Raises:
            ValueError: if no audio codes remain to decode.
        """
        speech_starts = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)[1]
        if len(speech_starts) > 0:
            generated_ids = generated_ids[:, speech_starts[-1].item() + 1 :]

        rows = [row[row != self.END_OF_SPEECH] for row in generated_ids]
        code_lists = self.prepare_audio_tokens_for_decoder(rows)
        if not code_lists:
            return {"error": "No audio samples generated"}
        if len(code_lists[0]) == 0:
            raise ValueError("Empty code list, no audio to generate")

        # Only the first (and, for a single prompt, only) sequence is returned
        audio = self.convert_codes_to_waveform(code_lists[0])
        audio_sample = audio.detach().squeeze().cpu().numpy()

        # Clip to [-1, 1] so out-of-range samples saturate in the 16-bit WAV
        # instead of relying on libsndfile's float-to-int conversion.
        buffer = io.BytesIO()
        sf.write(
            buffer,
            np.clip(audio_sample, -1.0, 1.0),
            samplerate=self.SAMPLE_RATE,
            format="WAV",
            subtype="PCM_16",
        )
        audio_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return {
            "audio_sample": audio_sample,
            "audio_b64": audio_b64,
            "sample_rate": self.SAMPLE_RATE,
        }