Update app.py
app.py
CHANGED
@@ -26,9 +26,8 @@ Codec_model.eval().to(device)
 whisper_turbo_pipe = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16 if …
-    device=…
-    …
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
+    device="cuda" if torch.cuda.is_available() else "cpu"
 )
 
 def ids_to_speech_tokens(speech_ids):
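Note on the hunk above: the new dtype branch selects bfloat16 on CPU, which is slow or unsupported on machines without native bf16 kernels. A minimal sketch of a more conservative selection (not part of this commit; float32 is the usual safe CPU default):

```python
import torch

# Sketch only: compute the device/dtype pair once so the pipeline config and
# the later .to(...) calls in infer() cannot drift apart.
USE_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if USE_CUDA else "cpu"
DTYPE = torch.float16 if USE_CUDA else torch.float32  # float32: safe CPU default
```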
@@ -47,56 +46,54 @@ def extract_speech_ids(speech_tokens_str):
     return speech_ids
 
 def infer(sample_audio_path, target_text, progress=gr.Progress()):
-    global tokenizer
+    global tokenizer
 
     if tokenizer is None:
         print("Warning: Tokenizer is missing, reloading...")
-        tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
+        tokenizer = AutoTokenizer.from_pretrained("llasa_3b")
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         progress(0.2, 'Loading audio...')
         waveform, sample_rate = torchaudio.load(sample_audio_path)
 
-        …
-        …
+        # ✅ Trim audio for compatibility
+        if waveform.size(1) / sample_rate > 60:
             waveform = waveform[:, :sample_rate * 60]
 
         progress(0.4, 'Trimming audio...')
         if waveform.shape[1] / sample_rate > 30:
-            print("Trimming audio to 30 seconds for Whisper ASR.")
             waveform = waveform[:, :sample_rate * 30]
 
         if waveform.size(0) > 1:
             waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
         else:
             waveform_mono = waveform
-        waveform_mono = waveform_mono.to(…
+        waveform_mono = waveform_mono.to("cuda" if torch.cuda.is_available() else "cpu")
 
         prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
         prompt_wav_np = prompt_wav[0].detach().cpu().numpy()
 
         progress(0.6, 'Transcribing speech...')
         try:
-            prompt_text = whisper_turbo_pipe(prompt_wav_np)['text'].strip()  # ✅
+            prompt_text = whisper_turbo_pipe(prompt_wav_np, language="en")['text'].strip()  # ✅ Force English transcription
         except Exception:
-            print("Whisper ASR failed. Retrying…
-            prompt_text = whisper_turbo_pipe(prompt_wav_np…
+            print("Whisper ASR failed. Retrying...")
+            prompt_text = whisper_turbo_pipe(prompt_wav_np)['text'].strip()
 
         if not prompt_text or prompt_text.lower() in ["error: unable to transcribe", ""]:
-            print("Warning: Whisper ASR output is empty. Defaulting to target text.")
             prompt_text = target_text
 
         progress(0.8, 'Generating synthesized audio...')
-
         if len(target_text) == 0:
             return None
         elif len(target_text) > 500:
             target_text = target_text[:500]
-            print("Text…
+            print("Text truncated to 500 characters.")
 
         input_text = " ".join(filter(None, [prompt_text.strip(), target_text.strip()]))
+
         with torch.no_grad():
-            vq_code_prompt = Codec_model.encode_code(…
+            vq_code_prompt = Codec_model.encode_code(prompt_wav)
             vq_code_prompt = vq_code_prompt[0,0,:]
             speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
 
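Two notes on the hunk above. First, `AutoTokenizer.from_pretrained("llasa_3b")` now passes a string literal where the removed line passed the `llasa_3b` variable; the literal only resolves if a local directory named `llasa_3b` exists. Second, depending on the installed transformers version, the ASR pipeline may not accept `language` as a direct call argument (the documented route is `generate_kwargs`), in which case the `try` branch raises immediately and every call falls through to the auto-detect retry. A minimal sketch of the intended behavior, assuming the `whisper_turbo_pipe` defined in this file:

```python
def transcribe(pipe, audio, language="en"):
    """Sketch only: force the transcription language via generate_kwargs,
    falling back to Whisper's language auto-detection if the forced run fails."""
    try:
        return pipe(audio, generate_kwargs={"language": language})["text"].strip()
    except Exception as exc:
        print(f"Forced-language ASR failed ({exc}); retrying with auto-detect.")
        return pipe(audio)["text"].strip()
```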
@@ -108,11 +105,11 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
         ]
 
         input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors='pt', continue_final_message=True)
-        input_ids = input_ids.to(…
+        input_ids = input_ids.to("cuda" if torch.cuda.is_available() else "cpu")
         speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
 
         if speech_end_id is None:
-            raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found…
+            raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found!")
 
         outputs = model.generate(
             input_ids,
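Note on the `speech_end_id` guard above: many tokenizers return the unknown-token id rather than `None` from `convert_tokens_to_ids`, so the `is None` check can pass even when the token is missing. A stricter sketch (assumes the `tokenizer` from this file):

```python
speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
# Sketch only: also reject the unk id, which convert_tokens_to_ids returns
# for out-of-vocabulary tokens on many tokenizers.
if speech_end_id is None or speech_end_id == tokenizer.unk_token_id:
    raise ValueError("`<|SPEECH_GENERATION_END|>` is missing from the tokenizer vocabulary.")
```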
@@ -125,19 +122,14 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
             temperature=0.7,
         )
 
-        …
-            print("Warning: Generated output is shorter than expected.")
-            generated_ids = outputs[0]
-        else:
-            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
-
+        generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
         speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         speech_tokens = extract_speech_ids(speech_tokens)
 
         if not speech_tokens:
-            raise ValueError("Error: No valid speech tokens extracted…
+            raise ValueError("Error: No valid speech tokens extracted!")
         else:
-            speech_tokens = torch.tensor(speech_tokens).to(…
+            speech_tokens = torch.tensor(speech_tokens).to("cuda" if torch.cuda.is_available() else "cpu").unsqueeze(0).unsqueeze(0)
 
         gen_wav = Codec_model.decode_code(speech_tokens)
         gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
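Note on the hunk above: the removed branch guarded against generations shorter than the prompt prefix; without it, a short output slices to an empty tensor and the failure only surfaces later at the `if not speech_tokens` check. A sketch that restores an explicit guard (assumes `outputs` has shape `[1, seq_len]`, as returned by `model.generate`):

```python
# Sketch only: fail fast when generation ended before any speech tokens.
start = input_ids.shape[1] - len(speech_ids_prefix)
if outputs.shape[1] <= start + 1:
    raise ValueError("Generated output is shorter than the speech prefix.")
generated_ids = outputs[0][start:-1]  # drop the final end-of-speech token
```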
@@ -145,7 +137,6 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
         progress(1.0, 'Complete!')
 
         return (16000, gen_wav[0, 0, :].cpu().numpy())
-
 with gr.Blocks() as app_tts:
     gr.Markdown("# Zero Shot Voice Clone TTS")
     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
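For reference, `infer()` returns a `(sample_rate, np.ndarray)` tuple, which `gr.Audio` accepts directly as an output value. A sketch of the wiring (the text box, button, and output names are placeholders, not components shown in this diff):

```python
with gr.Blocks() as app_tts:
    gr.Markdown("# Zero Shot Voice Clone TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate")  # placeholder name
    audio_output = gr.Audio(label="Synthesized Audio")     # placeholder name
    generate_btn = gr.Button("Generate")
    # The gr.Progress() default in infer's signature is injected by Gradio.
    generate_btn.click(infer, inputs=[ref_audio_input, gen_text_input], outputs=audio_output)
```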