Spaces: Bradarr / Running on Zero

Bradarr committed · Commit 09bb564 · verified · 1 parent: ef55fce

Update app.py

Files changed (1):
  app.py (+159, -224)
app.py CHANGED
@@ -1,5 +1,4 @@
  import os
-
  import gradio as gr
  import numpy as np
  import spaces
@@ -8,251 +7,187 @@ import torchaudio
  from generator import Segment, load_csm_1b
  from huggingface_hub import hf_hub_download, login
  from watermarking import watermark

- api_key = os.getenv("HF_TOKEN")
- gpu_timeout = int(os.getenv("GPU_TIMEOUT", 60))
- CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
-
- login(token=api_key)
-
- SPACE_INTRO_TEXT = """\
- # Sesame CSM 1B
-
- Generate from CSM 1B (Conversational Speech Model).
- Code is available on GitHub: [SesameAILabs/csm](https://github.com/SesameAILabs/csm).
- Checkpoint is [hosted on HuggingFace](https://huggingface.co/sesame/csm-1b).

- Try out our interactive demo [sesame.com/voicedemo](https://www.sesame.com/voicedemo),
- this uses a fine-tuned variant of CSM.

- The model has some capacity for non-English languages due to data contamination in the training
- data, but it is likely not to perform well.

- ---

- """

- CONVO_INTRO_TEXT = """\
- ## Conversation content

- Each line is an utterance in the conversation to generate. Speakers alternate between A and B, starting with speaker A.
  """

- DEFAULT_CONVERSATION = """\
- Hey how are you doing.
- Pretty good, pretty good.
- I'm great, so happy to be speaking to you.
- Me too, this is some cool stuff huh?
- Yeah, I've been reading more about speech generation, and it really seems like context is important.
- Definitely.
- """

- SPEAKER_PROMPTS = {
-     "conversational_a": {
-         "text": (
-             "like revising for an exam I'd have to try and like keep up the momentum because I'd "
-             "start really early I'd be like okay I'm gonna start revising now and then like "
-             "you're revising for ages and then I just like start losing steam I didn't do that "
-             "for the exam we had recently to be fair that was a more of a last minute scenario "
-             "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
-             "sort of start the day with this not like a panic but like a"
-         ),
-         "audio": "prompts/conversational_a.wav",
-     },
-     "conversational_b": {
-         "text": (
-             "like a super Mario level. Like it's very like high detail. And like, once you get "
-             "into the park, it just like, everything looks like a computer game and they have all "
-             "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
-             "will have like a question block. And if you like, you know, punch it, a coin will "
-             "come out. So like everyone, when they come into the park, they get like this little "
-             "bracelet and then you can go punching question blocks around."
-         ),
-         "audio": "prompts/conversational_b.wav",
-     },
-     "read_speech_a": {
-         "text": (
-             "And Lake turned round upon me, a little abruptly, his odd yellowish eyes, a little "
-             "like those of the sea eagle, and the ghost of his smile that flickered on his "
-             "singularly pale face, with a stern and insidious look, confronted me."
-         ),
-         "audio": "prompts/read_speech_a.wav",
-     },
-     "read_speech_b": {
-         "text": (
-             "He was such a big boy that he wore high boots and carried a jack knife. He gazed and "
-             "gazed at the cap, and could not keep from fingering the blue tassel."
-         ),
-         "audio": "prompts/read_speech_b.wav",
-     },
-     "read_speech_c": {
-         "text": (
-             "All passed so quickly, there was so much going on around him, the Tree quite forgot "
-             "to look to himself."
-         ),
-         "audio": "prompts/read_speech_c.wav",
-     },
-     "read_speech_d": {
-         "text": (
-             "Suddenly I was back in the old days Before you felt we ought to drift apart. It was "
-             "some trick-the way your eyebrows raise."
-         ),
-         "audio": "prompts/read_speech_d.wav",
-     },
- }
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
- generator = load_csm_1b(model_path, device)


  @spaces.GPU(duration=gpu_timeout)
- def infer(
-     text_prompt_speaker_a,
-     text_prompt_speaker_b,
-     audio_prompt_speaker_a,
-     audio_prompt_speaker_b,
-     gen_conversation_input,
- ) -> tuple[np.ndarray, int]:
-     # Estimate token limit, otherwise failure might happen after many utterances have been generated.
-     if len(gen_conversation_input.strip() + text_prompt_speaker_a.strip() + text_prompt_speaker_b.strip()) >= 2000:
-         raise gr.Error("Prompts and conversation too long.", duration=30)
-
      try:
-         return _infer(
-             text_prompt_speaker_a,
-             text_prompt_speaker_b,
-             audio_prompt_speaker_a,
-             audio_prompt_speaker_b,
-             gen_conversation_input,
-         )
-     except ValueError as e:
-         raise gr.Error(f"Error generating audio: {e}", duration=120)
-
-
- def _infer(
-     text_prompt_speaker_a,
-     text_prompt_speaker_b,
-     audio_prompt_speaker_a,
-     audio_prompt_speaker_b,
-     gen_conversation_input,
- ) -> tuple[np.ndarray, int]:
-     audio_prompt_a = prepare_prompt(text_prompt_speaker_a, 0, audio_prompt_speaker_a)
-     audio_prompt_b = prepare_prompt(text_prompt_speaker_b, 1, audio_prompt_speaker_b)
-
-     prompt_segments: list[Segment] = [audio_prompt_a, audio_prompt_b]
-     generated_segments: list[Segment] = []
-
-     conversation_lines = [line.strip() for line in gen_conversation_input.strip().split("\n") if line.strip()]
-     for i, line in enumerate(conversation_lines):
-         # Alternating speakers A and B, starting with A
-         speaker_id = i % 2
-
-         audio_tensor = generator.generate(
-             text=line,
-             speaker=speaker_id,
-             context=prompt_segments + generated_segments,
-             max_audio_length_ms=30_000,
-         )
-         generated_segments.append(Segment(text=line, speaker=speaker_id, audio=audio_tensor))

-     # Concatenate all generations and convert to 16-bit int format
-     audio_tensors = [segment.audio for segment in generated_segments]
-     audio_tensor = torch.cat(audio_tensors, dim=0)

-     # This applies an imperceptible watermark to identify audio as AI-generated.
-     # Watermarking ensures transparency, dissuades misuse, and enables traceability.
-     # Please be a responsible AI citizen and keep the watermarking in place.
-     # If using CSM 1B in another application, use your own private key and keep it secret.
-     audio_tensor, wm_sample_rate = watermark(
-         generator._watermarker, audio_tensor, generator.sample_rate, CSM_1B_HF_WATERMARK
-     )
-     audio_tensor = torchaudio.functional.resample(
-         audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
-     )
-
-     audio_array = (audio_tensor * 32768).to(torch.int16).cpu().numpy()
-
-     return generator.sample_rate, audio_array
-
-
- def prepare_prompt(text: str, speaker: int, audio_path: str) -> Segment:
-     audio_tensor, _ = load_prompt_audio(audio_path)
-     return Segment(text=text, speaker=speaker, audio=audio_tensor)
-
-
- def load_prompt_audio(audio_path: str) -> torch.Tensor:
-     audio_tensor, sample_rate = torchaudio.load(audio_path)
-     audio_tensor = audio_tensor.squeeze(0)
-     if sample_rate != generator.sample_rate:
-         audio_tensor = torchaudio.functional.resample(
-             audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
          )
-     return audio_tensor, generator.sample_rate
-
-
- def create_speaker_prompt_ui(speaker_name: str):
-     speaker_dropdown = gr.Dropdown(
-         choices=list(SPEAKER_PROMPTS.keys()), label="Select a predefined speaker", value=speaker_name
-     )
-     with gr.Accordion("Or add your own voice prompt", open=False):
-         text_prompt_speaker = gr.Textbox(label="Speaker prompt", lines=4, value=SPEAKER_PROMPTS[speaker_name]["text"])
-         audio_prompt_speaker = gr.Audio(
-             label="Speaker prompt", type="filepath", value=SPEAKER_PROMPTS[speaker_name]["audio"]
          )

-     return speaker_dropdown, text_prompt_speaker, audio_prompt_speaker


  with gr.Blocks() as app:
      gr.Markdown(SPACE_INTRO_TEXT)
-     gr.Markdown("## Voices")
-     with gr.Row():
-         with gr.Column():
-             gr.Markdown("### Speaker A")
-             speaker_a_dropdown, text_prompt_speaker_a, audio_prompt_speaker_a = create_speaker_prompt_ui(
-                 "conversational_a"
-             )

-         with gr.Column():
-             gr.Markdown("### Speaker B")
-             speaker_b_dropdown, text_prompt_speaker_b, audio_prompt_speaker_b = create_speaker_prompt_ui(
-                 "conversational_b"
-             )

-     def update_audio(speaker):
-         if speaker in SPEAKER_PROMPTS:
-             return SPEAKER_PROMPTS[speaker]["audio"]
-         return None
-
-     def update_text(speaker):
-         if speaker in SPEAKER_PROMPTS:
-             return SPEAKER_PROMPTS[speaker]["text"]
-         return None
-
-     speaker_a_dropdown.change(fn=update_audio, inputs=[speaker_a_dropdown], outputs=[audio_prompt_speaker_a])
-     speaker_b_dropdown.change(fn=update_audio, inputs=[speaker_b_dropdown], outputs=[audio_prompt_speaker_b])
-
-     speaker_a_dropdown.change(fn=update_text, inputs=[speaker_a_dropdown], outputs=[text_prompt_speaker_a])
-     speaker_b_dropdown.change(fn=update_text, inputs=[speaker_b_dropdown], outputs=[text_prompt_speaker_b])
-
-     gr.Markdown(CONVO_INTRO_TEXT)
-
-     gen_conversation_input = gr.TextArea(label="conversation", lines=20, value=DEFAULT_CONVERSATION)
-     generate_btn = gr.Button("Generate conversation", variant="primary")
-     gr.Markdown("GPU time limited to 3 minutes, for longer usage duplicate the space.")
-     audio_output = gr.Audio(label="Synthesized audio")
-
-     generate_btn.click(
-         infer,
-         inputs=[
-             text_prompt_speaker_a,
-             text_prompt_speaker_b,
-             audio_prompt_speaker_a,
-             audio_prompt_speaker_b,
-             gen_conversation_input,
-         ],
-         outputs=[audio_output],
-     )
-
- app.launch(ssr_mode=True)
 
  import os
  import gradio as gr
  import numpy as np
  import spaces
  from generator import Segment, load_csm_1b
  from huggingface_hub import hf_hub_download, login
  from watermarking import watermark
+ import whisperx
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import logging

+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
+ # Authentication and Configuration
+ try:
+     api_key = os.getenv("HF_TOKEN")
+     if not api_key:
+         raise ValueError("HF_TOKEN not found in environment variables.")
+     login(token=api_key)

+     watermark_key = os.getenv("WATERMARK_KEY")
+     if not watermark_key:
+         raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
+     CSM_1B_HF_WATERMARK = list(map(int, watermark_key.split(" ")))

+     gpu_timeout = int(os.getenv("GPU_TIMEOUT", 180))
+ except (ValueError, TypeError) as e:
+     logging.error(f"Configuration error: {e}")
+     raise  # Re-raise the exception to halt the application
 
+ SPACE_INTRO_TEXT = """\
+ # Sesame CSM 1B - Conversational Demo

+ This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.

+ *Disclaimer: This demo relies on several large models. Expect longer processing times and potential resource limitations.*
  """
 
+ # Model Loading
+ try:
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
+     generator = load_csm_1b(model_path, device)
+     logging.info("Sesame CSM 1B loaded successfully.")
+
+     whisper_model = whisperx.load_model("large-v2", device)
+     # Alignment model is loaded once at startup; language_code="en" is an assumption, adjust for other languages.
+     model_a, align_metadata = whisperx.load_align_model(language_code="en", device=device)
+     logging.info("WhisperX model loaded successfully.")
+
+     # Load Gemma 3 1B (base) - adjust model name if needed
+     tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
+     model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
+     logging.info("Gemma 3 1B model loaded successfully.")
+
+ except Exception as e:
+     logging.error(f"Model loading error: {e}")
+     raise  # Re-raise to prevent the app from launching with incomplete models
+
+ # Constants
+ SPEAKER_ID = 0  # Arbitrary speaker ID used for both sides of the conversation
+ MAX_CONTEXT_SEGMENTS = 5
+ MAX_GEMMA_LENGTH = 300  # Keep generations short for the 1B model
+
+ # Global conversation history (note: module-level, so it is shared across sessions of the Space)
+ conversation_history = []
+
+ # --- HELPER FUNCTIONS ---
+ def transcribe_audio(audio_path: str) -> str:
+     """Transcribes audio using WhisperX."""
+     try:
+         audio = whisperx.load_audio(audio_path)
+         result = whisper_model.transcribe(audio, batch_size=16)  # Added batch_size

+         # Align Whisper output
+         result_aligned = whisperx.align(result["segments"], model_a, align_metadata, audio, device, return_char_alignments=False)

+         return " ".join(segment["text"].strip() for segment in result_aligned["segments"])
+     except Exception as e:
+         logging.error(f"WhisperX transcription error: {e}")
+         return "Error: Could not transcribe audio."  # Return an error message
 
+ def generate_response(text: str) -> str:
+     """Generates a response using Gemma."""
+     try:
+         input_text = "Here is a response for the user. " + text
+         inputs = tokenizer_gemma(input_text, return_tensors="pt").to(device)
+         generated_output = model_gemma.generate(**inputs, max_length=MAX_GEMMA_LENGTH, early_stopping=True)  # Added early_stopping
+         return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
+     except Exception as e:
+         logging.error(f"Gemma response generation error: {e}")
+         return "I'm sorry, I encountered an error generating a response."  # Error fallback
+
+ def load_audio(audio_path: str) -> torch.Tensor:
+     """Loads audio from file and returns a torch tensor."""
+     try:
+         audio_tensor, sample_rate = torchaudio.load(audio_path)
+         audio_tensor = audio_tensor.mean(dim=0)  # Mono audio
+         if sample_rate != generator.sample_rate:
+             audio_tensor = torchaudio.functional.resample(
+                 audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
+             )
+         return audio_tensor
+     except Exception as e:
+         logging.error(f"Audio loading error: {e}")
+         raise gr.Error("Could not load or process the audio file.") from e  # Re-raise as Gradio error
+
+ def clear_history():
+     """Clears the conversation history"""
+     global conversation_history
+     conversation_history = []
+     logging.info("Conversation history cleared.")
+     return "Conversation history cleared."
+
+ # --- MAIN INFERENCE FUNCTION ---
  @spaces.GPU(duration=gpu_timeout)
+ def infer(user_audio) -> tuple[int, np.ndarray]:  # Return sample_rate as int
+     """Infers a response from the user audio."""
      try:
+         if not user_audio:
+             raise ValueError("No audio input received.")
+         return _infer(user_audio)
+     except Exception as e:
+         logging.exception(f"Inference error: {e}")  # Log the full exception
+         raise gr.Error(f"An error occurred during processing: {e}")
 
+ def _infer(user_audio) -> tuple[int, np.ndarray]:  # Return sample_rate as int
+     """Processes the user input, generates a response, and returns audio."""
+     global conversation_history  # Declare to modify the global list

+     try:
+         # 1. ASR: Transcribe user audio using WhisperX
+         user_text = transcribe_audio(user_audio)
+         logging.info(f"User: {user_text}")
+
+         # 2. LLM: Generate a response using Gemma
+         ai_text = generate_response(user_text)
+         logging.info(f"AI: {ai_text}")
+
+         # 3. Generate audio using the CSM model
+         try:
+             ai_audio = generator.generate(
+                 text=ai_text,
+                 speaker=SPEAKER_ID,
+                 context=conversation_history,
+                 max_audio_length_ms=30_000,
+             )
+             logging.info("Audio generated successfully.")
+         except Exception as e:
+             logging.error(f"CSM audio generation error: {e}")
+             raise gr.Error(f"CSM audio generation error: {e}")  # Error fallback
+
+         # Update conversation history with the user input and AI response.
+         user_segment = Segment(speaker=SPEAKER_ID, text=user_text, audio=load_audio(user_audio))
+         ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
+         conversation_history.append(user_segment)
+         conversation_history.append(ai_segment)
+
+         # Limit conversation history
+         while len(conversation_history) > MAX_CONTEXT_SEGMENTS:
+             conversation_history.pop(0)
+
+         # 4. Watermarking and Audio Conversion
+         audio_tensor, wm_sample_rate = watermark(
+             generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
          )
+         audio_tensor = torchaudio.functional.resample(
+             audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
          )

+         ai_audio_array = (audio_tensor * 32768).to(torch.int16).cpu().numpy()
+         return generator.sample_rate, ai_audio_array

+     except Exception as e:
+         logging.exception(f"Error in _infer: {e}")
+         # Log the full exception including stack trace for debugging.
+         # It's crucial to log the *exception*, not just the error message.
+         raise gr.Error(f"An error occurred during processing: {e}")
 
+ # --- GRADIO INTERFACE ---
  with gr.Blocks() as app:
      gr.Markdown(SPACE_INTRO_TEXT)
+     # "sources" (list) is the current Gradio keyword; older releases used source="microphone"
+     audio_input = gr.Audio(label="Your Input", sources=["microphone"], type="filepath")
+     audio_output = gr.Audio(label="AI Response")
+     clear_button = gr.Button("Clear Conversation History")
+     status_display = gr.Textbox(label="Status", visible=False)

+     btn = gr.Button("Generate Response")
+     btn.click(infer, inputs=[audio_input], outputs=[audio_output])
+     clear_button.click(clear_history, outputs=[status_display])  # No input needed

+ app.launch(ssr_mode=True)
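
As a quick sanity check outside the Gradio UI, the new ASR → Gemma → CSM loop could be driven from a plain script along the lines of the sketch below. This is not part of the commit: it assumes `app.py` is importable without launching the server (e.g. with the final `app.launch(...)` call moved under an `if __name__ == "__main__":` guard), that a local recording named `user_turn.wav` exists, and that `scipy` is installed; the file names are hypothetical.

```python
# Minimal driver sketch for the new conversational pipeline (not part of this commit).
# Assumptions: app.py is importable without starting the UI, the models fit on this
# machine, "user_turn.wav" exists locally, and scipy is installed.
from scipy.io import wavfile

import app  # importing app.py loads CSM 1B, WhisperX, and Gemma

# infer() runs the full loop: WhisperX transcription -> Gemma reply -> CSM speech -> watermarking.
sample_rate, ai_audio = app.infer("user_turn.wav")

# ai_audio is int16 PCM at the CSM sample rate; write it to disk for listening.
wavfile.write("ai_response.wav", sample_rate, ai_audio)
print(f"Wrote ai_response.wav ({sample_rate} Hz); history holds {len(app.conversation_history)} segments.")
```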