OmniAICreator committed (verified)
Commit d53472f · Parent: 776d153

Update app.py

Files changed (1): app.py (+104 -80)
app.py CHANGED
@@ -5,22 +5,21 @@ import soundfile as sf
 from xcodec2.modeling_xcodec2 import XCodec2Model
 import torchaudio
 import gradio as gr
-import tempfile
 
-llasa_3b ='srinivasbilla/llasa-3b'
+llasa_model_id = 'OmniAICreator/Galgame-Llasa-3B'
 
-tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
+tokenizer = AutoTokenizer.from_pretrained(llasa_model_id)
 
 model = AutoModelForCausalLM.from_pretrained(
-    llasa_3b,
+    llasa_model_id,
     trust_remote_code=True,
-    device_map='cuda',
 )
+model.eval().cuda()
 
-model_path = "srinivasbilla/xcodec2"
+xcodec2_model_id = "HKUSTAudio/xcodec2"
 
-Codec_model = XCodec2Model.from_pretrained(model_path)
-Codec_model.eval().cuda()
+codec_model = XCodec2Model.from_pretrained(xcodec2_model_id)
+codec_model.eval().cuda()
 
 whisper_turbo_pipe = pipeline(
     "automatic-speech-recognition",
@@ -50,87 +49,105 @@ def extract_speech_ids(speech_tokens_str):
     return speech_ids
 
 @spaces.GPU(duration=60)
-def infer(sample_audio_path, target_text, progress=gr.Progress()):
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-        progress(0, 'Loading and trimming audio...')
-        waveform, sample_rate = torchaudio.load(sample_audio_path)
-        if len(waveform[0])/sample_rate > 15:
-            gr.Warning("Trimming audio to first 15secs.")
-            waveform = waveform[:, :sample_rate*15]
-
-        # Check if the audio is stereo (i.e., has more than one channel)
-        if waveform.size(0) > 1:
-            # Convert stereo to mono by averaging the channels
-            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
-        else:
-            # If already mono, just use the original waveform
-            waveform_mono = waveform
-
-        prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
-        prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
-        progress(0.5, 'Transcribed! Generating speech...')
-
-        if len(target_text) == 0:
-            return None
-        elif len(target_text) > 300:
-            gr.Warning("Text is too long. Please keep it under 300 characters.")
-            target_text = target_text[:300]
-
-        input_text = prompt_text + ' ' + target_text
-
-        #TTS start!
-        with torch.no_grad():
+def infer(sample_audio_path, target_text, temperature, top_p, progress=gr.Progress()):
+    if not target_text or not target_text.strip():
+        gr.Warning("Please input text to generate audio.")
+        return None, None
+    if len(target_text) > 300:
+        gr.Warning("Text is too long. Please keep it under 300 characters.")
+        target_text = target_text[:300]
+    with torch.no_grad():
+        if sample_audio_path:
+            progress(0, 'Loading and trimming audio...')
+            waveform, sample_rate = torchaudio.load(sample_audio_path)
+            if len(waveform[0])/sample_rate > 15:
+                gr.Warning("Trimming audio to first 15secs.")
+                waveform = waveform[:, :sample_rate*15]
+
+            # Check if the audio is stereo (i.e., has more than one channel)
+            if waveform.size(0) > 1:
+                # Convert stereo to mono by averaging the channels
+                waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
+            else:
+                # If already mono, just use the original waveform
+                waveform_mono = waveform
+
+            prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
+            prompt_wav_len = prompt_wav.shape[1]
+            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
+            progress(0.5, 'Transcribed! Encoding audio...')
+
             # Encode the prompt wav
-            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
+            vq_code_prompt = codec_model.encode_code(input_waveform=prompt_wav)[0, 0, :]
 
-            vq_code_prompt = vq_code_prompt[0,0,:]
             # Convert int 12345 to token <|s_12345|>
             speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
 
-            formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
-
-            # Tokenize the text and the speech prefix
-            chat = [
-                {"role": "user", "content": "Convert the text to speech:" + formatted_text},
-                {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
-            ]
-
-            input_ids = tokenizer.apply_chat_template(
-                chat,
-                tokenize=True,
-                return_tensors='pt',
-                continue_final_message=True
-            )
-            input_ids = input_ids.to('cuda')
-            speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
-
-            # Generate the speech autoregressively
-            outputs = model.generate(
-                input_ids,
-                max_length=2048, # We trained our model with a max length of 2048
-                eos_token_id= speech_end_id ,
-                do_sample=True,
-                top_p=1,
-                temperature=0.8
-            )
-            # Extract the speech tokens
+            input_text = prompt_text + ' ' + target_text
+
+            assistant_content = "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)
+        else:
+            progress(0, "Preparing...")
+            input_text = target_text
+            assistant_content = "<|SPEECH_GENERATION_START|>"
+            speech_ids_prefix = []
+            prompt_wav_len = 0
+
+        progress(0.75, "Generating audio...")
+
+        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+
+        # Tokenize the text and the speech prefix
+        chat = [
+            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
+            {"role": "assistant", "content": assistant_content}
+        ]
+
+        input_ids = tokenizer.apply_chat_template(
+            chat,
+            tokenize=True,
+            return_tensors='pt',
+            continue_final_message=True
+        ).to('cuda')
+
+        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
+
+        # Generate the speech autoregressively
+        outputs = model.generate(
+            input_ids,
+            max_length=2048, # We trained our model with a max length of 2048
+            eos_token_id=speech_end_id,
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature
+        )
+
+        # Extract the speech tokens
+        if sample_audio_path:
             generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
+        else:
+            generated_ids = outputs[0][input_ids.shape[1]:-1]
 
-            speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 
-            # Convert token <|s_23456|> to int 23456
-            speech_tokens = extract_speech_ids(speech_tokens)
+        # Convert token <|s_23456|> to int 23456
+        speech_tokens = extract_speech_ids(speech_tokens)
 
-            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
+        if not speech_tokens:
+            gr.Error("Audio generation failed.")
+            return None
 
-            # Decode the speech tokens to speech waveform
-            gen_wav = Codec_model.decode_code(speech_tokens)
+        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
 
-            # if only need the generated part
-            gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
+        # Decode the speech tokens to speech waveform
+        gen_wav = codec_model.decode_code(speech_tokens)
 
-            progress(1, 'Synthesized!')
+        # if only need the generated part
+        if sample_audio_path and prompt_wav_len > 0:
+            gen_wav = gen_wav[:, :, prompt_wav_len:]
 
+        progress(1, 'Synthesized!')
+
     return (16000, gen_wav[0, 0, :].cpu().numpy())
 
 with gr.Blocks() as app_tts:
@@ -138,6 +155,10 @@ with gr.Blocks() as app_tts:
     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
     gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
 
+    with gr.Row():
+        temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Temperature")
+        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p")
+
     generate_btn = gr.Button("Synthesize", variant="primary")
 
     audio_output = gr.Audio(label="Synthesized Audio")
@@ -147,6 +168,8 @@ with gr.Blocks() as app_tts:
         inputs=[
            ref_audio_input,
            gen_text_input,
+           temperature_slider,
+           top_p_slider,
        ],
        outputs=[audio_output],
    )
@@ -156,17 +179,18 @@ with gr.Blocks() as app_credits:
 # Credits
 
 * [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
-* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+* [SunderAli17](https://huggingface.co/SunderAli17) for the [gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
 """)
 
 with gr.Blocks() as app:
     gr.Markdown(
         """
-# llasa 3b TTS
+# Galgame Llasa 3B
 
-This is a local web UI for llasa 3b SOTA(imo) Zero Shot Voice Cloning and TTS model.
+This is a local web UI for Galgame Llasa 3B TTS model.
 
-The checkpoints support English and Chinese.
+The model is fine-tuned by Japanese audio data.
 
 If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
 """