
Bradarr committed · Commit 0219bf8 · verified · Parent: 48078de

Update app.py

Files changed (1): app.py (+37, -29)
app.py CHANGED

@@ -28,48 +28,55 @@ try:
     gpu_timeout = int(os.getenv("GPU_TIMEOUT", 180))
 except (ValueError, TypeError) as e:
     logging.error(f"Configuration error: {e}")
-    raise # Re-raise the exception to halt the application
+    raise
 
 SPACE_INTRO_TEXT = """\
 # Sesame CSM 1B - Conversational Demo
 
-This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
+This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
 
-*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
+*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
 """
 
 # Constants
 SPEAKER_ID = 0 # Arbitrary speaker ID
 MAX_CONTEXT_SEGMENTS = 5
-MAX_GEMMA_LENGTH = 300 #Reduce for the 1.1 2b model
+MAX_GEMMA_LENGTH = 300
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Global conversation history (important: keep it inside app scope)
+# Global conversation history
 conversation_history = []
 
 # Global variables to hold loaded models
 global_generator = None
 global_whisper_model = None
 global_model_a = None
-global_whisper_metadata = None
+# global_whisper_metadata = None # No longer needed at the global level
 global_tokenizer_gemma = None
 global_model_gemma = None
 
+
 # --- HELPER FUNCTIONS ---
-def transcribe_audio(audio_path: str, whisper_model, model_a, whisper_metadata) -> str:
-    """Transcribes audio using WhisperX."""
+
+def transcribe_audio(audio_path: str, whisper_model, model_a) -> str: # Removed whisper_metadata
+    """Transcribes audio using WhisperX and aligns it."""
     try:
         audio = whisperx.load_audio(audio_path)
         result = whisper_model.transcribe(audio, batch_size=16)
+        # Get language from the result. Much more reliable.
+        language = result["language"]
+
 
         # Align Whisper output
-        result_aligned = whisperx.align(result["segments"], model_a, whisper_metadata, audio, whisper_model, device, return_char_alignments=False)
+        model_a, metadata = whisperx.load_align_model(language_code=language, device=device) #Load it here to ensure metadata is extracted.
+        result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
 
         return result_aligned["segments"][0]["text"]
     except Exception as e:
         logging.error(f"WhisperX transcription error: {e}")
         return "Error: Could not transcribe audio."
 
+
 def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
     """Generates a response using Gemma."""
     try:
@@ -81,12 +88,13 @@ def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
         logging.error(f"Gemma response generation error: {e}")
         return "I'm sorry, I encountered an error generating a response."
 
+
 def load_audio(audio_path: str) -> torch.Tensor:
     """Loads audio from file and returns a torch tensor."""
     try:
         audio_tensor, sample_rate = torchaudio.load(audio_path)
         audio_tensor = audio_tensor.mean(dim=0) # Mono audio
-        if sample_rate != global_generator.sample_rate: #Access via global generator
+        if sample_rate != global_generator.sample_rate:
             audio_tensor = torchaudio.functional.resample(
                 audio_tensor, orig_freq=sample_rate, new_freq=global_generator.sample_rate
             )
@@ -95,6 +103,7 @@ def load_audio(audio_path: str) -> torch.Tensor:
         logging.error(f"Audio loading error: {e}")
         raise gr.Error("Could not load or process the audio file.") from e
 
+
 def clear_history():
     """Clears the conversation history"""
     global conversation_history
@@ -102,11 +111,13 @@ def clear_history():
     logging.info("Conversation history cleared.")
     return "Conversation history cleared."
 
+
 # --- MAIN INFERENCE FUNCTION ---
+
 @spaces.GPU(gpu_timeout=gpu_timeout)
 def infer(user_audio) -> tuple:
     """Infers a response from the user audio."""
-    global global_generator, global_whisper_model, global_model_a, global_whisper_metadata, global_tokenizer_gemma, global_model_gemma, device
+    global global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma, device
 
     try:
         if not user_audio:
@@ -119,27 +130,26 @@ def infer(user_audio) -> tuple:
             logging.info("Sesame CSM 1B loaded successfully on GPU.")
 
         if global_whisper_model is None:
-            global_whisper_model, global_whisper_metadata = whisperx.load_model("large-v2", device)
-            global_model_a, _ = whisperx.load_align_model(language_code=global_whisper_metadata.language, device=device)
+            global_whisper_model = whisperx.load_model("large-v2", device) # No unpacking
             logging.info("WhisperX model loaded successfully on GPU.")
 
         if global_tokenizer_gemma is None:
             global_tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
             global_model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
             logging.info("Gemma 3 1B pt model loaded successfully on GPU.")
-
-        return _infer(user_audio, global_generator, global_whisper_model, global_model_a, global_whisper_metadata, global_tokenizer_gemma, global_model_gemma)
+        return _infer(user_audio, global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma) #Removed Metadata
     except Exception as e:
         logging.exception(f"Inference error: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
 
-def _infer(user_audio, generator, whisper_model, model_a, whisper_metadata, tokenizer_gemma, model_gemma) -> tuple:
+
+def _infer(user_audio, generator, whisper_model, model_a, tokenizer_gemma, model_gemma) -> tuple:
     """Processes the user input, generates a response, and returns audio."""
     global conversation_history
 
     try:
         # 1. ASR: Transcribe user audio using WhisperX
-        user_text = transcribe_audio(user_audio, whisper_model, model_a, whisper_metadata)
+        user_text = transcribe_audio(user_audio, whisper_model, model_a) #Removed Metadata
        logging.info(f"User: {user_text}")
 
         # 2. LLM: Generate a response using Gemma
@@ -147,17 +157,13 @@ def _infer(user_audio, generator, whisper_model, model_a, whisper_metadata, tokenizer_gemma, model_gemma) -> tuple:
         logging.info(f"AI: {ai_text}")
 
         # 3. Generate audio using the CSM model
-        try:
-            ai_audio = generator.generate(
-                text=ai_text,
-                speaker=SPEAKER_ID,
-                context=conversation_history,
-                max_audio_length_ms=30_000,
-            )
-            logging.info("Audio generated successfully.")
-        except Exception as e:
-            logging.error(f"CSM response generation error: {e}")
-            raise gr.Error(f"CSM response generation error: {e}")
+        ai_audio = generator.generate(
+            text=ai_text,
+            speaker=SPEAKER_ID,
+            context=conversation_history,
+            max_audio_length_ms=30_000,
+        )
+        logging.info("Audio generated successfully.")
 
         #Update conversation history with user input and ai response.
         user_segment = Segment(speaker = SPEAKER_ID, text = 'User Audio', audio = load_audio(user_audio))
@@ -184,7 +190,9 @@ def _infer(user_audio, generator, whisper_model, model_a, whisper_metadata, tokenizer_gemma, model_gemma) -> tuple:
         logging.exception(f"Error in _infer: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
 
+
 # --- GRADIO INTERFACE ---
+
 with gr.Blocks() as app:
     gr.Markdown(SPACE_INTRO_TEXT)
     audio_input = gr.Audio(label="Your Input", type="filepath")
@@ -196,4 +204,4 @@ with gr.Blocks() as app:
     btn.click(infer, inputs=[audio_input], outputs=[audio_output])
     clear_button.click(clear_history, outputs=[status_display])
 
-app.launch(share=False) #Add share = True for public link
+app.launch(share=False)
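
The core change is in transcribe_audio: instead of threading a whisper_metadata global through the call chain (and passing a stray whisper_model argument to whisperx.align), the commit loads the alignment model per call, keyed on the language WhisperX detects. A minimal standalone sketch of that flow, using only the calls visible in the diff ("sample.wav" is a placeholder path):

```python
# Sketch of the WhisperX transcribe-then-align flow the commit adopts.
# "sample.wav" is a placeholder path, not a file from this Space.
import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"

audio = whisperx.load_audio("sample.wav")
model = whisperx.load_model("large-v2", device)  # returns the model alone, no metadata tuple
result = model.transcribe(audio, batch_size=16)  # result includes the detected "language"

# Alignment metadata depends on the detected language, which is why the
# commit loads the align model inside transcribe_audio rather than once globally.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], model_a, metadata, audio, device,
                         return_char_alignments=False)
print(aligned["segments"][0]["text"])
```

One side effect visible in the diff: the model_a argument that infer still passes into transcribe_audio is immediately shadowed by this per-call load, and global_model_a is no longer assigned anywhere.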
 
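The body of generate_response is elided by the hunk context (only its surroundings change). For orientation, a typical causal-LM call assembled from the identifiers this file does show (tokenizer_gemma, model_gemma, MAX_GEMMA_LENGTH, device) might look like the sketch below; it is an illustration under those assumptions, not the Space's actual implementation.

```python
# Hypothetical generate_response, built only from names visible in the diff;
# the real body is not shown in this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_GEMMA_LENGTH = 300

tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)

def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
    """Generates a response using Gemma (sketch)."""
    inputs = tokenizer_gemma(text, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model_gemma.generate(**inputs, max_length=MAX_GEMMA_LENGTH)
    return tokenizer_gemma.decode(output_ids[0], skip_special_tokens=True)

print(generate_response("Hello! How are you today?", tokenizer_gemma, model_gemma))
```

Capping generation at MAX_GEMMA_LENGTH total tokens matches the constant's comment history (it was reduced for the smaller model), keeping responses short enough for the 30-second CSM audio budget.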