from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import os

api_key = os.getenv("HF_TOKEN")

# NOTE: the variable name says 3B, but the checkpoint loaded here is Llasa-8B.
llasa_3b = 'HKUSTAudio/Llasa-8B'
tokenizer = AutoTokenizer.from_pretrained(llasa_3b, token=api_key)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    llasa_3b,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.bfloat16,
    device_map="auto",
    token=api_key,
)

# XCodec2 converts speech into discrete codec ids and back into waveforms.
model_path = "srinivasbilla/xcodec2"
Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().to(device)

# Whisper transcribes the reference clip so its text can prefix the TTS prompt.
whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.bfloat16,
    device=device,
)


def ids_to_speech_tokens(speech_ids):
    """Wrap raw codec ids in the <|s_N|> token format the LLM expects."""
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]


def extract_speech_ids(speech_tokens_str):
    """Parse <|s_N|> tokens back into integer codec ids, skipping malformed ones."""
    speech_ids = []
    for token_str in speech_tokens_str:
        try:
            if token_str.startswith('<|s_') and token_str.endswith('|>'):
                speech_ids.append(int(token_str[4:-2]))
            else:
                raise ValueError(f"Unexpected token format: {token_str}")
        except ValueError as e:
            print(f"Error parsing speech token: {e}")
    return speech_ids


def infer(sample_audio_path, target_text, progress=gr.Progress()):
    global tokenizer
    if tokenizer is None:
        print("Warning: Tokenizer is missing, reloading...")
        tokenizer = AutoTokenizer.from_pretrained(llasa_3b, token=api_key)

    progress(0.2, 'Loading audio...')
    waveform, sample_rate = torchaudio.load(sample_audio_path)

    progress(0.4, 'Trimming audio...')
    # Trim the reference audio for compatibility: keep at most 30 seconds.
    if waveform.shape[1] / sample_rate > 30:
        waveform = waveform[:, :sample_rate * 30]

    # Downmix to mono if the clip has more than one channel.
    if waveform.size(0) > 1:
        waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
    else:
        waveform_mono = waveform

    # Resample on CPU (Resample's kernel buffer lives on CPU), then move to the device.
    prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
    prompt_wav_np = prompt_wav[0].numpy()
    prompt_wav = prompt_wav.to(device)

    progress(0.6, 'Transcribing speech...')
    try:
        result = whisper_turbo_pipe(prompt_wav_np, generate_kwargs={"language": "en"})
        prompt_text = result['text'].strip()
    except Exception as e:
        print(f"Whisper ASR failed. Retrying...\nError: {e}")
Error: {e}") try: result = whisper_turbo_pipe(input_features=prompt_wav_np, generate_kwargs={"language": "en"}) prompt_text = result['text'].strip() except Exception as retry_e: print(f"Retry also failed: {retry_e}") prompt_text = target_text if not prompt_text or prompt_text.lower() in ["error: unable to transcribe", ""]: prompt_text = target_text progress(0.8, 'Generating synthesized audio...') if len(target_text) == 0: return None elif len(target_text) > 500: target_text = target_text[:500] print("Text truncated to 500 characters.") input_text = " ".join(filter(None, [prompt_text.strip(), target_text.strip()])) with torch.no_grad(): vq_code_prompt = Codec_model.encode_code(prompt_wav) vq_code_prompt = vq_code_prompt[0,0,:] speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt) formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>" chat = [ {"role": "user", "content": "Convert the text to speech:" + formatted_text}, {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)} ] input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors='pt', continue_final_message=True) input_ids = input_ids.to("cuda" if torch.cuda.is_available() else "cpu") speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>') if speech_end_id is None: raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found!") outputs = model.generate( input_ids, max_length=2048, eos_token_id=speech_end_id, pad_token_id=tokenizer.eos_token_id, attention_mask=input_ids.ne(tokenizer.pad_token_id), do_sample=True, top_p=0.9, temperature=0.7, ) generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1] speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) speech_tokens = extract_speech_ids(speech_tokens) if not speech_tokens: raise ValueError("Error: No valid speech tokens extracted!") else: speech_tokens = torch.tensor(speech_tokens).to("cuda" if torch.cuda.is_available() else "cpu").unsqueeze(0).unsqueeze(0) gen_wav = Codec_model.decode_code(speech_tokens) gen_wav = gen_wav[:,:,prompt_wav.shape[1]:] progress(1.0, 'Complete!') return (16000, gen_wav[0, 0, :].cpu().numpy()) with gr.Blocks() as app_tts: gr.Markdown("# Zero Shot Voice Clone TTS") ref_audio_input = gr.Audio(label="Reference Audio", type="filepath") gen_text_input = gr.Textbox(label="Text to Generate", lines=10) generate_btn = gr.Button("Synthesize", variant="primary") audio_output = gr.Audio(label="Synthesized Audio") generate_btn.click(infer, inputs=[ref_audio_input, gen_text_input], outputs=[audio_output]) with gr.Blocks() as app: gr.TabbedInterface([app_tts], ["TTS"]) app.launch(debug=False, ssr_mode=False)