Hematej commited on
Commit
902f49d
·
verified ·
1 Parent(s): 0e999cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -12
app.py CHANGED
@@ -33,18 +33,26 @@ whisper_turbo_pipe = pipeline(
33
  )
34
 
35
  def ids_to_speech_tokens(speech_ids):
36
- return [f"<|s_{speech_id}|>" for speech_id in speech_ids]
 
 
 
 
37
 
38
  def extract_speech_ids(speech_tokens_str):
 
39
  speech_ids = []
40
  for token_str in speech_tokens_str:
41
  if token_str.startswith('<|s_') and token_str.endswith('|>'):
42
- try:
43
- speech_ids.append(int(token_str[4:-2]))
44
- except ValueError:
45
- print(f"Unexpected token: {token_str}")
 
 
46
  return speech_ids
47
 
 
48
  @spaces.GPU(duration=60)
49
  def infer(sample_audio_path, target_text, progress=gr.Progress()):
50
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
@@ -54,10 +62,12 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
54
  gr.Warning("Trimming audio to first 15secs.")
55
  waveform = waveform[:, :sample_rate*15]
56
 
57
- # Convert stereo to mono dynamically
58
- waveform_mono = waveform.mean(dim=0, keepdim=True) if waveform.size(0) > 1 else waveform
59
- waveform_mono = waveform_mono.to(device)
60
-
 
 
61
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
62
  prompt_text = whisper_turbo_pipe(prompt_wav[0].cpu().numpy(), language="en")['text'].strip() # ✅ Force English transcription
63
  progress(0.5, 'Transcribed! Generating speech...')
@@ -68,10 +78,10 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
68
  gr.Warning("Text is too long. Please keep it under 300 characters.")
69
  target_text = target_text[:300]
70
 
71
- input_text = f"{prompt_text} {target_text}"
72
 
73
  with torch.no_grad():
74
- vq_code_prompt = Codec_model.encode_code(prompt_wav)
75
  vq_code_prompt = vq_code_prompt[0,0,:]
76
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
77
 
@@ -110,7 +120,7 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
110
  if not speech_tokens:
111
  raise ValueError("Error: No valid speech tokens extracted!")
112
 
113
- speech_tensor = torch.tensor(speech_tokens).unsqueeze(0).unsqueeze(0).to(device)
114
 
115
  gen_wav = Codec_model.decode_code(speech_tensor)
116
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
 
33
  )
34
 
35
def ids_to_speech_tokens(speech_ids):
    """Render each integer speech ID as a '<|s_N|>' token string.

    Args:
        speech_ids: iterable of integer codec IDs.

    Returns:
        list[str] of tokens in the same order as the input IDs.
    """
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]
41
 
42
def extract_speech_ids(speech_tokens_str):
    """Parse '<|s_N|>' token strings back into their integer IDs.

    Tokens that do not match the '<|s_...|>' shape, or whose payload is
    not a valid integer (e.g. '<|s_x|>'), are reported via print() and
    skipped — they never raise, matching the pre-refactor behavior.

    Args:
        speech_tokens_str: iterable of token strings.

    Returns:
        list[int] of successfully parsed speech IDs, in input order.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            # Guard the int() parse: the prefix/suffix check alone does not
            # guarantee a numeric payload, and the old code tolerated this.
            try:
                speech_ids.append(int(num_str))
            except ValueError:
                print(f"Unexpected token: {token_str}")
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
54
 
55
+
56
  @spaces.GPU(duration=60)
57
  def infer(sample_audio_path, target_text, progress=gr.Progress()):
58
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
 
62
  gr.Warning("Trimming audio to first 15secs.")
63
  waveform = waveform[:, :sample_rate*15]
64
 
65
+ if waveform.size(0) > 1:
66
+ # Convert stereo to mono by averaging the channels
67
+ waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
68
+ else:
69
+ # If already mono, just use the original waveform
70
+ waveform_mono = waveform
71
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
72
  prompt_text = whisper_turbo_pipe(prompt_wav[0].cpu().numpy(), language="en")['text'].strip() # ✅ Force English transcription
73
  progress(0.5, 'Transcribed! Generating speech...')
 
78
  gr.Warning("Text is too long. Please keep it under 300 characters.")
79
  target_text = target_text[:300]
80
 
81
+ input_text = prompt_text + ' ' + target_text
82
 
83
  with torch.no_grad():
84
+ vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
85
  vq_code_prompt = vq_code_prompt[0,0,:]
86
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
87
 
 
120
  if not speech_tokens:
121
  raise ValueError("Error: No valid speech tokens extracted!")
122
 
123
+ speech_tokens = torch.tensor(speech_tokens).unsqueeze(0).unsqueeze(0).to(device)
124
 
125
  gen_wav = Codec_model.decode_code(speech_tensor)
126
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]