Spaces:
Bradarr
/
Running on Zero

Bradarr commited on
Commit
721e588
·
verified ·
1 Parent(s): 76e7434

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -45
app.py CHANGED
@@ -38,41 +38,29 @@ This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisp
38
  *Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
39
  """
40
 
41
- # Model Loading
42
- @spaces.GPU()
43
- try:
44
- device = "cuda" if torch.cuda.is_available() else "cpu"
45
- model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
46
- generator = load_csm_1b(model_path, device)
47
- logging.info("Sesame CSM 1B loaded successfully.")
48
-
49
- whisper_model, whisper_metadata = whisperx.load_model("large-v2", device)
50
- model_a, whisper_metadata = whisperx.load_align_model(language_code=whisper_metadata.language, device=device)
51
- logging.info("WhisperX model loaded successfully.")
52
-
53
- # Load Gemma 1.1 2B - adjust model name if needed
54
- tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
55
- model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
56
- logging.info("Gemma 3 1B pt model loaded successfully.")
57
-
58
- except Exception as e:
59
- logging.error(f"Model loading error: {e}")
60
- raise # Re-raise to prevent the app from launching with incomplete models
61
-
62
  # Constants
63
  SPEAKER_ID = 0 # Arbitrary speaker ID
64
  MAX_CONTEXT_SEGMENTS = 5
65
- MAX_GEMMA_LENGTH = 300 #Reduce for the 1.1 2b model
 
66
 
67
  # Global conversation history (important: keep it inside app scope)
68
  conversation_history = []
69
 
 
 
 
 
 
 
 
 
70
  # --- HELPER FUNCTIONS ---
71
- def transcribe_audio(audio_path: str) -> str:
72
  """Transcribes audio using WhisperX."""
73
  try:
74
  audio = whisperx.load_audio(audio_path)
75
- result = whisper_model.transcribe(audio, batch_size=16) # Added batch_size
76
 
77
  # Align Whisper output
78
  result_aligned = whisperx.align(result["segments"], model_a, whisper_metadata, audio, whisper_model, device, return_char_alignments=False)
@@ -80,32 +68,32 @@ def transcribe_audio(audio_path: str) -> str:
80
  return result_aligned["segments"][0]["text"]
81
  except Exception as e:
82
  logging.error(f"WhisperX transcription error: {e}")
83
- return "Error: Could not transcribe audio." # Return an error message
84
 
85
- def generate_response(text: str) -> str:
86
  """Generates a response using Gemma."""
87
  try:
88
  input_text = "Here is a response for the user. " + text
89
  input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
90
- generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True) # Added early_stopping
91
  return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
92
  except Exception as e:
93
  logging.error(f"Gemma response generation error: {e}")
94
- return "I'm sorry, I encountered an error generating a response." # Error fallback
95
 
96
  def load_audio(audio_path: str) -> torch.Tensor:
97
  """Loads audio from file and returns a torch tensor."""
98
  try:
99
  audio_tensor, sample_rate = torchaudio.load(audio_path)
100
  audio_tensor = audio_tensor.mean(dim=0) # Mono audio
101
- if sample_rate != generator.sample_rate:
102
  audio_tensor = torchaudio.functional.resample(
103
- audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
104
  )
105
  return audio_tensor
106
  except Exception as e:
107
  logging.error(f"Audio loading error: {e}")
108
- raise gr.Error("Could not load or process the audio file.") from e # Re-raise as Gradio error
109
 
110
  def clear_history():
111
  """Clears the conversation history"""
@@ -115,28 +103,47 @@ def clear_history():
115
  return "Conversation history cleared."
116
 
117
  # --- MAIN INFERENCE FUNCTION ---
118
- #@spaces.GPU()
119
- def infer(user_audio) -> tuple[int, np.ndarray]: # Return sample_rate as int
120
  """Infers a response from the user audio."""
 
 
121
  try:
122
  if not user_audio:
123
  raise ValueError("No audio input received.")
124
- return _infer(user_audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
- logging.exception(f"Inference error: {e}") # Log the full exception
127
  raise gr.Error(f"An error occurred during processing: {e}")
128
 
129
- def _infer(user_audio) -> tuple[int, np.ndarray]: # Return sample_rate as int
130
  """Processes the user input, generates a response, and returns audio."""
131
- global conversation_history # Declare to modify the global list
132
 
133
  try:
134
  # 1. ASR: Transcribe user audio using WhisperX
135
- user_text = transcribe_audio(user_audio)
136
  logging.info(f"User: {user_text}")
137
 
138
  # 2. LLM: Generate a response using Gemma
139
- ai_text = generate_response(user_text)
140
  logging.info(f"AI: {ai_text}")
141
 
142
  # 3. Generate audio using the CSM model
@@ -149,8 +156,8 @@ def _infer(user_audio) -> tuple[int, np.ndarray]: # Return sample_rate as int
149
  )
150
  logging.info("Audio generated successfully.")
151
  except Exception as e:
152
- logging.error(f"Gemma response generation error: {e}")
153
- raise gr.Error(f"Gemma response generation error: {e}") # Error fallback
154
 
155
  #Update conversation history with user input and ai response.
156
  user_segment = Segment(speaker = SPEAKER_ID, text = 'User Audio', audio = load_audio(user_audio))
@@ -175,8 +182,6 @@ def _infer(user_audio) -> tuple[int, np.ndarray]: # Return sample_rate as int
175
 
176
  except Exception as e:
177
  logging.exception(f"Error in _infer: {e}")
178
- # Log the full exception including stack trace for debugging.
179
- # It's crucial to log the *exception*, not just the error message.
180
  raise gr.Error(f"An error occurred during processing: {e}")
181
 
182
  # --- GRADIO INTERFACE ---
@@ -189,6 +194,6 @@ with gr.Blocks() as app:
189
 
190
  btn = gr.Button("Generate Response")
191
  btn.click(infer, inputs=[audio_input], outputs=[audio_output])
192
- clear_button.click(clear_history, outputs=[status_display]) # No input needed
193
 
194
- app.launch(ssr_mode=True)
 
38
  *Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Constants
42
  SPEAKER_ID = 0 # Arbitrary speaker ID
43
  MAX_CONTEXT_SEGMENTS = 5
44
+ MAX_GEMMA_LENGTH = 300 #Reduce for the 1.1 2b model
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
  # Global conversation history (important: keep it inside app scope)
48
  conversation_history = []
49
 
50
+ # Global variables to hold loaded models
51
+ global_generator = None
52
+ global_whisper_model = None
53
+ global_model_a = None
54
+ global_whisper_metadata = None
55
+ global_tokenizer_gemma = None
56
+ global_model_gemma = None
57
+
58
  # --- HELPER FUNCTIONS ---
59
+ def transcribe_audio(audio_path: str, whisper_model, model_a, whisper_metadata) -> str:
60
  """Transcribes audio using WhisperX."""
61
  try:
62
  audio = whisperx.load_audio(audio_path)
63
+ result = whisper_model.transcribe(audio, batch_size=16)
64
 
65
  # Align Whisper output
66
  result_aligned = whisperx.align(result["segments"], model_a, whisper_metadata, audio, whisper_model, device, return_char_alignments=False)
 
68
  return result_aligned["segments"][0]["text"]
69
  except Exception as e:
70
  logging.error(f"WhisperX transcription error: {e}")
71
+ return "Error: Could not transcribe audio."
72
 
73
+ def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
74
  """Generates a response using Gemma."""
75
  try:
76
  input_text = "Here is a response for the user. " + text
77
  input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
78
+ generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
79
  return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
80
  except Exception as e:
81
  logging.error(f"Gemma response generation error: {e}")
82
+ return "I'm sorry, I encountered an error generating a response."
83
 
84
  def load_audio(audio_path: str) -> torch.Tensor:
85
  """Loads audio from file and returns a torch tensor."""
86
  try:
87
  audio_tensor, sample_rate = torchaudio.load(audio_path)
88
  audio_tensor = audio_tensor.mean(dim=0) # Mono audio
89
+ if sample_rate != global_generator.sample_rate: #Access via global generator
90
  audio_tensor = torchaudio.functional.resample(
91
+ audio_tensor, orig_freq=sample_rate, new_freq=global_generator.sample_rate
92
  )
93
  return audio_tensor
94
  except Exception as e:
95
  logging.error(f"Audio loading error: {e}")
96
+ raise gr.Error("Could not load or process the audio file.") from e
97
 
98
  def clear_history():
99
  """Clears the conversation history"""
 
103
  return "Conversation history cleared."
104
 
105
  # --- MAIN INFERENCE FUNCTION ---
106
+ @spaces.GPU(gpu_timeout=gpu_timeout)
107
+ def infer(user_audio) -> tuple:
108
  """Infers a response from the user audio."""
109
+ global global_generator, global_whisper_model, global_model_a, global_whisper_metadata, global_tokenizer_gemma, global_model_gemma, device
110
+
111
  try:
112
  if not user_audio:
113
  raise ValueError("No audio input received.")
114
+
115
+ # Load models if not already loaded
116
+ if global_generator is None:
117
+ model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
118
+ global_generator = load_csm_1b(model_path, device)
119
+ logging.info("Sesame CSM 1B loaded successfully on GPU.")
120
+
121
+ if global_whisper_model is None:
122
+ global_whisper_model, global_whisper_metadata = whisperx.load_model("large-v2", device)
123
+ global_model_a, _ = whisperx.load_align_model(language_code=global_whisper_metadata.language, device=device)
124
+ logging.info("WhisperX model loaded successfully on GPU.")
125
+
126
+ if global_tokenizer_gemma is None:
127
+ global_tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
128
+ global_model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
129
+ logging.info("Gemma 3 1B pt model loaded successfully on GPU.")
130
+
131
+ return _infer(user_audio, global_generator, global_whisper_model, global_model_a, global_whisper_metadata, global_tokenizer_gemma, global_model_gemma)
132
  except Exception as e:
133
+ logging.exception(f"Inference error: {e}")
134
  raise gr.Error(f"An error occurred during processing: {e}")
135
 
136
+ def _infer(user_audio, generator, whisper_model, model_a, whisper_metadata, tokenizer_gemma, model_gemma) -> tuple:
137
  """Processes the user input, generates a response, and returns audio."""
138
+ global conversation_history
139
 
140
  try:
141
  # 1. ASR: Transcribe user audio using WhisperX
142
+ user_text = transcribe_audio(user_audio, whisper_model, model_a, whisper_metadata)
143
  logging.info(f"User: {user_text}")
144
 
145
  # 2. LLM: Generate a response using Gemma
146
+ ai_text = generate_response(user_text, tokenizer_gemma, model_gemma)
147
  logging.info(f"AI: {ai_text}")
148
 
149
  # 3. Generate audio using the CSM model
 
156
  )
157
  logging.info("Audio generated successfully.")
158
  except Exception as e:
159
+ logging.error(f"CSM response generation error: {e}")
160
+ raise gr.Error(f"CSM response generation error: {e}")
161
 
162
  #Update conversation history with user input and ai response.
163
  user_segment = Segment(speaker = SPEAKER_ID, text = 'User Audio', audio = load_audio(user_audio))
 
182
 
183
  except Exception as e:
184
  logging.exception(f"Error in _infer: {e}")
 
 
185
  raise gr.Error(f"An error occurred during processing: {e}")
186
 
187
  # --- GRADIO INTERFACE ---
 
194
 
195
  btn = gr.Button("Generate Response")
196
  btn.click(infer, inputs=[audio_input], outputs=[audio_output])
197
+ clear_button.click(clear_history, outputs=[status_display])
198
 
199
+ app.launch(share=False) #Add share = True for public link