Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -33,18 +33,26 @@ whisper_turbo_pipe = pipeline(
|
|
33 |
)
|
34 |
|
35 |
def ids_to_speech_tokens(speech_ids):
|
36 |
-
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def extract_speech_ids(speech_tokens_str):
|
|
|
39 |
speech_ids = []
|
40 |
for token_str in speech_tokens_str:
|
41 |
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
return speech_ids
|
47 |
|
|
|
48 |
@spaces.GPU(duration=60)
|
49 |
def infer(sample_audio_path, target_text, progress=gr.Progress()):
|
50 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
@@ -54,10 +62,12 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
|
|
54 |
gr.Warning("Trimming audio to first 15secs.")
|
55 |
waveform = waveform[:, :sample_rate*15]
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
61 |
prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
|
62 |
prompt_text = whisper_turbo_pipe(prompt_wav[0].cpu().numpy(), language="en")['text'].strip() # ✅ Force English transcription
|
63 |
progress(0.5, 'Transcribed! Generating speech...')
|
@@ -68,10 +78,10 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
|
|
68 |
gr.Warning("Text is too long. Please keep it under 300 characters.")
|
69 |
target_text = target_text[:300]
|
70 |
|
71 |
-
input_text =
|
72 |
|
73 |
with torch.no_grad():
|
74 |
-
vq_code_prompt = Codec_model.encode_code(prompt_wav)
|
75 |
vq_code_prompt = vq_code_prompt[0,0,:]
|
76 |
speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
|
77 |
|
@@ -110,7 +120,7 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
|
|
110 |
if not speech_tokens:
|
111 |
raise ValueError("Error: No valid speech tokens extracted!")
|
112 |
|
113 |
-
|
114 |
|
115 |
gen_wav = Codec_model.decode_code(speech_tensor)
|
116 |
gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
|
|
|
33 |
)
|
34 |
|
35 |
def ids_to_speech_tokens(speech_ids):
|
36 |
+
|
37 |
+
speech_tokens_str = []
|
38 |
+
for speech_id in speech_ids:
|
39 |
+
speech_tokens_str.append(f"<|s_{speech_id}|>")
|
40 |
+
return speech_tokens_str
|
41 |
|
42 |
def extract_speech_ids(speech_tokens_str):
|
43 |
+
|
44 |
speech_ids = []
|
45 |
for token_str in speech_tokens_str:
|
46 |
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
47 |
+
num_str = token_str[4:-2]
|
48 |
+
|
49 |
+
num = int(num_str)
|
50 |
+
speech_ids.append(num)
|
51 |
+
else:
|
52 |
+
print(f"Unexpected token: {token_str}")
|
53 |
return speech_ids
|
54 |
|
55 |
+
|
56 |
@spaces.GPU(duration=60)
|
57 |
def infer(sample_audio_path, target_text, progress=gr.Progress()):
|
58 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
|
|
62 |
gr.Warning("Trimming audio to first 15secs.")
|
63 |
waveform = waveform[:, :sample_rate*15]
|
64 |
|
65 |
+
if waveform.size(0) > 1:
|
66 |
+
# Convert stereo to mono by averaging the channels
|
67 |
+
waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
|
68 |
+
else:
|
69 |
+
# If already mono, just use the original waveform
|
70 |
+
waveform_mono = waveform
|
71 |
prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
|
72 |
prompt_text = whisper_turbo_pipe(prompt_wav[0].cpu().numpy(), language="en")['text'].strip() # ✅ Force English transcription
|
73 |
progress(0.5, 'Transcribed! Generating speech...')
|
|
|
78 |
gr.Warning("Text is too long. Please keep it under 300 characters.")
|
79 |
target_text = target_text[:300]
|
80 |
|
81 |
+
input_text = prompt_text + ' ' + target_text
|
82 |
|
83 |
with torch.no_grad():
|
84 |
+
vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
|
85 |
vq_code_prompt = vq_code_prompt[0,0,:]
|
86 |
speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
|
87 |
|
|
|
120 |
if not speech_tokens:
|
121 |
raise ValueError("Error: No valid speech tokens extracted!")
|
122 |
|
123 |
+
speech_tokens = torch.tensor(speech_tokens).unsqueeze(0).unsqueeze(0).to(device)
|
124 |
|
125 |
gen_wav = Codec_model.decode_code(speech_tensor)
|
126 |
gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
|