Hematej commited on
Commit
5a53ea0
·
verified ·
1 Parent(s): e50b6dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -27
app.py CHANGED
@@ -26,9 +26,8 @@ Codec_model.eval().to(device)
26
  whisper_turbo_pipe = pipeline(
27
  "automatic-speech-recognition",
28
  model="openai/whisper-large-v3-turbo",
29
- torch_dtype=torch.float16 if device == "cuda" else torch.bfloat16,
30
- device=device,
31
-
32
  )
33
 
34
  def ids_to_speech_tokens(speech_ids):
@@ -47,56 +46,54 @@ def extract_speech_ids(speech_tokens_str):
47
  return speech_ids
48
 
49
  def infer(sample_audio_path, target_text, progress=gr.Progress()):
50
- global tokenizer # ✅ Declare before using
51
 
52
  if tokenizer is None:
53
  print("Warning: Tokenizer is missing, reloading...")
54
- tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
55
 
56
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
57
  progress(0.2, 'Loading audio...')
58
  waveform, sample_rate = torchaudio.load(sample_audio_path)
59
 
60
- if len(waveform[0]) / sample_rate > 60:
61
- print("Trimming audio to first 1 minute.")
62
  waveform = waveform[:, :sample_rate * 60]
63
 
64
  progress(0.4, 'Trimming audio...')
65
  if waveform.shape[1] / sample_rate > 30:
66
- print("Trimming audio to 30 seconds for Whisper ASR.")
67
  waveform = waveform[:, :sample_rate * 30]
68
 
69
  if waveform.size(0) > 1:
70
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
71
  else:
72
  waveform_mono = waveform
73
- waveform_mono = waveform_mono.to(device)
74
 
75
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
76
  prompt_wav_np = prompt_wav[0].detach().cpu().numpy()
77
 
78
  progress(0.6, 'Transcribing speech...')
79
  try:
80
- prompt_text = whisper_turbo_pipe(prompt_wav_np)['text'].strip() # ✅ First call without timestamps
81
  except Exception:
82
- print("Whisper ASR failed. Retrying without timestamps...")
83
- prompt_text = whisper_turbo_pipe(prompt_wav_np, return_timestamps=False)['text'].strip()
84
 
85
  if not prompt_text or prompt_text.lower() in ["error: unable to transcribe", ""]:
86
- print("Warning: Whisper ASR output is empty. Defaulting to target text.")
87
  prompt_text = target_text
88
 
89
  progress(0.8, 'Generating synthesized audio...')
90
-
91
  if len(target_text) == 0:
92
  return None
93
  elif len(target_text) > 500:
94
  target_text = target_text[:500]
95
- print("Text is too long. Please keep it under 500 characters.")
96
 
97
  input_text = " ".join(filter(None, [prompt_text.strip(), target_text.strip()]))
 
98
  with torch.no_grad():
99
- vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
100
  vq_code_prompt = vq_code_prompt[0,0,:]
101
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
102
 
@@ -108,11 +105,11 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
108
  ]
109
 
110
  input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors='pt', continue_final_message=True)
111
- input_ids = input_ids.to(device)
112
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
113
 
114
  if speech_end_id is None:
115
- raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found in tokenizer!")
116
 
117
  outputs = model.generate(
118
  input_ids,
@@ -125,19 +122,14 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
125
  temperature=0.7,
126
  )
127
 
128
- if len(outputs[0]) < input_ids.shape[1] - len(speech_ids_prefix):
129
- print("Warning: Generated output is shorter than expected.")
130
- generated_ids = outputs[0]
131
- else:
132
- generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
133
-
134
  speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
135
  speech_tokens = extract_speech_ids(speech_tokens)
136
 
137
  if not speech_tokens:
138
- raise ValueError("Error: No valid speech tokens extracted—speech synthesis may fail.")
139
  else:
140
- speech_tokens = torch.tensor(speech_tokens).to(device).unsqueeze(0).unsqueeze(0)
141
 
142
  gen_wav = Codec_model.decode_code(speech_tokens)
143
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
@@ -145,7 +137,6 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
145
  progress(1.0, 'Complete!')
146
 
147
  return (16000, gen_wav[0, 0, :].cpu().numpy())
148
-
149
  with gr.Blocks() as app_tts:
150
  gr.Markdown("# Zero Shot Voice Clone TTS")
151
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
 
26
  whisper_turbo_pipe = pipeline(
27
  "automatic-speech-recognition",
28
  model="openai/whisper-large-v3-turbo",
29
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
30
+ device="cuda" if torch.cuda.is_available() else "cpu"
 
31
  )
32
 
33
  def ids_to_speech_tokens(speech_ids):
 
46
  return speech_ids
47
 
48
  def infer(sample_audio_path, target_text, progress=gr.Progress()):
49
+ global tokenizer
50
 
51
  if tokenizer is None:
52
  print("Warning: Tokenizer is missing, reloading...")
53
+ tokenizer = AutoTokenizer.from_pretrained("llasa_3b")
54
 
55
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
56
  progress(0.2, 'Loading audio...')
57
  waveform, sample_rate = torchaudio.load(sample_audio_path)
58
 
59
+ # Trim audio for compatibility
60
+ if waveform.size(1) / sample_rate > 60:
61
  waveform = waveform[:, :sample_rate * 60]
62
 
63
  progress(0.4, 'Trimming audio...')
64
  if waveform.shape[1] / sample_rate > 30:
 
65
  waveform = waveform[:, :sample_rate * 30]
66
 
67
  if waveform.size(0) > 1:
68
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
69
  else:
70
  waveform_mono = waveform
71
+ waveform_mono = waveform_mono.to("cuda" if torch.cuda.is_available() else "cpu")
72
 
73
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
74
  prompt_wav_np = prompt_wav[0].detach().cpu().numpy()
75
 
76
  progress(0.6, 'Transcribing speech...')
77
  try:
78
+ prompt_text = whisper_turbo_pipe(prompt_wav_np, language="en")['text'].strip() # ✅ Force English transcription
79
  except Exception:
80
+ print("Whisper ASR failed. Retrying...")
81
+ prompt_text = whisper_turbo_pipe(prompt_wav_np)['text'].strip()
82
 
83
  if not prompt_text or prompt_text.lower() in ["error: unable to transcribe", ""]:
 
84
  prompt_text = target_text
85
 
86
  progress(0.8, 'Generating synthesized audio...')
 
87
  if len(target_text) == 0:
88
  return None
89
  elif len(target_text) > 500:
90
  target_text = target_text[:500]
91
+ print("Text truncated to 500 characters.")
92
 
93
  input_text = " ".join(filter(None, [prompt_text.strip(), target_text.strip()]))
94
+
95
  with torch.no_grad():
96
+ vq_code_prompt = Codec_model.encode_code(prompt_wav)
97
  vq_code_prompt = vq_code_prompt[0,0,:]
98
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
99
 
 
105
  ]
106
 
107
  input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors='pt', continue_final_message=True)
108
+ input_ids = input_ids.to("cuda" if torch.cuda.is_available() else "cpu")
109
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
110
 
111
  if speech_end_id is None:
112
+ raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found!")
113
 
114
  outputs = model.generate(
115
  input_ids,
 
122
  temperature=0.7,
123
  )
124
 
125
+ generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
 
 
 
 
 
126
  speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
127
  speech_tokens = extract_speech_ids(speech_tokens)
128
 
129
  if not speech_tokens:
130
+ raise ValueError("Error: No valid speech tokens extracted!")
131
  else:
132
+ speech_tokens = torch.tensor(speech_tokens).to("cuda" if torch.cuda.is_available() else "cpu").unsqueeze(0).unsqueeze(0)
133
 
134
  gen_wav = Codec_model.decode_code(speech_tokens)
135
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
 
137
  progress(1.0, 'Complete!')
138
 
139
  return (16000, gen_wav[0, 0, :].cpu().numpy())
 
140
  with gr.Blocks() as app_tts:
141
  gr.Markdown("# Zero Shot Voice Clone TTS")
142
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")