Spaces · Bradarr · Running on Zero

Bradarr committed · commit cc8bd68 · verified · 1 parent: c51f2cc

Update app.py

Files changed (1): app.py (+65 −90)
app.py CHANGED
@@ -7,14 +7,14 @@ import torchaudio
 from generator import Segment, load_csm_1b
 from huggingface_hub import hf_hub_download, login
 from watermarking import watermark
-import whisperx
+import whisper
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Authentication and Configuration
+# --- Authentication and Configuration --- (Moved BEFORE model loading)
 try:
     api_key = os.getenv("HF_TOKEN")
     if not api_key:
@@ -30,55 +30,37 @@ except (ValueError, TypeError) as e:
     logging.error(f"Configuration error: {e}")
     raise
 
-SPACE_INTRO_TEXT = """\
+SPACE_INTRO_TEXT = """
 # Sesame CSM 1B - Conversational Demo
 
-This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
+This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisper for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
 
-*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
+*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
 """
 
-# Constants
-SPEAKER_ID = 0 # Arbitrary speaker ID
+# --- Model Loading --- (Moved INSIDE infer function)
+
+# --- Constants --- (Constants can stay outside)
+SPEAKER_ID = 0
 MAX_CONTEXT_SEGMENTS = 5
 MAX_GEMMA_LENGTH = 300
-device = "cuda" # if torch.cuda.is_available() else "cpu"
 
-# Global conversation history
+# --- Global Conversation History ---
 conversation_history = []
 
-# Global variables to hold loaded models
-global_generator = None
-global_whisper_model = None
-global_model_a = None
-# global_whisper_metadata = None # No longer needed at the global level
-global_tokenizer_gemma = None
-global_model_gemma = None
-
-
-# --- HELPER FUNCTIONS ---
+# --- Helper Functions ---
 
-def transcribe_audio(audio_path: str, whisper_model, model_a) -> str: # Removed whisper_metadata
-    """Transcribes audio using WhisperX and aligns it."""
+def transcribe_audio(audio_path: str, whisper_model) -> str: # Pass whisper_model
     try:
-        audio = whisperx.load_audio(audio_path)
-        result = whisper_model.transcribe(audio, batch_size=16)
-        # Get language from the result. Much more reliable.
-        language = result["language"]
-
-
-        # Align Whisper output
-        model_a, metadata = whisperx.load_align_model(language_code=language, device=device) #Load it here to ensure metadata is extracted.
-        result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
-
-        return result_aligned["segments"][0]["text"]
+        audio = whisper.load_audio(audio_path)
+        audio = whisper.pad_or_trim(audio)
+        result = whisper_model.transcribe(audio)
+        return result["text"]
     except Exception as e:
-        logging.error(f"WhisperX transcription error: {e}")
+        logging.error(f"Whisper transcription error: {e}")
        return "Error: Could not transcribe audio."
 
-
-def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
-    """Generates a response using Gemma."""
+def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: # Pass model and tokenizer
     try:
         input_text = "Here is a response for the user. " + text
         input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
@@ -88,94 +70,88 @@ def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
         logging.error(f"Gemma response generation error: {e}")
         return "I'm sorry, I encountered an error generating a response."
 
-
-def load_audio(audio_path: str) -> torch.Tensor:
-    """Loads audio from file and returns a torch tensor."""
+def load_audio(audio_path: str, generator) -> torch.Tensor: #Pass generator
     try:
         audio_tensor, sample_rate = torchaudio.load(audio_path)
-        audio_tensor = audio_tensor.mean(dim=0) # Mono audio
-        if sample_rate != global_generator.sample_rate:
-            audio_tensor = torchaudio.functional.resample(
-                audio_tensor, orig_freq=sample_rate, new_freq=global_generator.sample_rate
-            )
+        audio_tensor = audio_tensor.mean(dim=0)
+        if sample_rate != generator.sample_rate:
+            audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate)
         return audio_tensor
     except Exception as e:
         logging.error(f"Audio loading error: {e}")
         raise gr.Error("Could not load or process the audio file.") from e
 
-
 def clear_history():
-    """Clears the conversation history"""
     global conversation_history
     conversation_history = []
     logging.info("Conversation history cleared.")
     return "Conversation history cleared."
 
-# --- MAIN INFERENCE FUNCTION ---
+# --- Main Inference Function ---
 
-@spaces.GPU(gpu_timeout=gpu_timeout)
-def infer(user_audio) -> tuple:
-    """Infers a response from the user audio."""
-    global global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma, device
+@spaces.GPU(duration=gpu_timeout) # Decorator FIRST
+def infer(user_audio) -> tuple[int, np.ndarray]:
+    # --- CUDA Availability Check (INSIDE infer) ---
+    if torch.cuda.is_available():
+        print(f"CUDA is available! Device count: {torch.cuda.device_count()}")
+        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
+        print(f"CUDA version: {torch.version.cuda}")
+        device = "cuda"
+    else:
+        print("CUDA is NOT available. Using CPU.") # Use CPU, don't raise
+        device = "cpu"
 
     try:
+        # --- Model Loading (INSIDE infer, after device is set) ---
+        model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
+        generator = load_csm_1b(model_path, device)
+        logging.info("Sesame CSM 1B loaded successfully.")
+
+        whisper_model = whisper.load_model("large-v2", device=device)
+        logging.info("Whisper model loaded successfully.")
+
+        tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
+        model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
+        logging.info("Gemma 3 1B pt model loaded successfully.")
+
         if not user_audio:
             raise ValueError("No audio input received.")
-
-        # Load models if not already loaded
-        if global_generator is None:
-            model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
-            global_generator = load_csm_1b(model_path, device)
-            logging.info("Sesame CSM 1B loaded successfully on GPU.")
-
-        if global_whisper_model is None:
-            global_whisper_model = whisperx.load_model("large-v2", device) # No unpacking
-            logging.info("WhisperX model loaded successfully on GPU.")
-
-        if global_tokenizer_gemma is None:
-            global_tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
-            global_model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
-            logging.info("Gemma 3 1B pt model loaded successfully on GPU.")
-        return _infer(user_audio, global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma) #Removed Metadata
+        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) #Pass all models
     except Exception as e:
         logging.exception(f"Inference error: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
 
-
-def _infer(user_audio, generator, whisper_model, model_a, tokenizer_gemma, model_gemma) -> tuple:
-    """Processes the user input, generates a response, and returns audio."""
+def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) -> tuple[int, np.ndarray]:
     global conversation_history
 
     try:
-        # 1. ASR: Transcribe user audio using WhisperX
-        user_text = transcribe_audio(user_audio, whisper_model, model_a) #Removed Metadata
+        user_text = transcribe_audio(user_audio, whisper_model) # Pass whisper_model
         logging.info(f"User: {user_text}")
 
-        # 2. LLM: Generate a response using Gemma
-        ai_text = generate_response(user_text, tokenizer_gemma, model_gemma)
+        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device) # Pass model and tokenizer
         logging.info(f"AI: {ai_text}")
 
-        # 3. Generate audio using the CSM model
-        ai_audio = generator.generate(
-            text=ai_text,
-            speaker=SPEAKER_ID,
-            context=conversation_history,
-            max_audio_length_ms=30_000,
-        )
-        logging.info("Audio generated successfully.")
+        try:
+            ai_audio = generator.generate(
+                text=ai_text,
+                speaker=SPEAKER_ID,
+                context=conversation_history,
+                max_audio_length_ms=30_000,
+            )
+            logging.info("Audio generated successfully.")
+        except Exception as e:
+            logging.error(f"Sesame response generation error: {e}")
+            raise gr.Error(f"Sesame response generation error: {e}")
 
-        #Update conversation history with user input and ai response.
-        user_segment = Segment(speaker = SPEAKER_ID, text = 'User Audio', audio = load_audio(user_audio))
+
+        user_segment = Segment(speaker = SPEAKER_ID, text = 'User Audio', audio = load_audio(user_audio, generator)) #Pass Generator
         ai_segment = Segment(speaker = SPEAKER_ID, text = 'AI Audio', audio = ai_audio)
         conversation_history.append(user_segment)
         conversation_history.append(ai_segment)
 
-        #Limit Conversation History
         if len(conversation_history) > MAX_CONTEXT_SEGMENTS:
             conversation_history.pop(0)
 
-        # 4. Watermarking and Audio Conversion
         audio_tensor, wm_sample_rate = watermark(
             generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
         )
@@ -190,8 +166,7 @@ def _infer(user_audio, generator, whisper_model, model_a, tokenizer_gemma, model
         logging.exception(f"Error in _infer: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
 
-
-# --- GRADIO INTERFACE ---
+# --- Gradio Interface ---
 
 with gr.Blocks() as app:
     gr.Markdown(SPACE_INTRO_TEXT)
@@ -204,4 +179,4 @@ with gr.Blocks() as app:
     btn.click(infer, inputs=[audio_input], outputs=[audio_output])
     clear_button.click(clear_history, outputs=[status_display])
 
-app.launch(share=False)
+app.launch(ssr_mode=False)
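The headline change is swapping WhisperX for plain openai-whisper: the transcribe-then-align pipeline (`whisperx.load_align_model` / `whisperx.align`) is gone, and `transcribe_audio` now returns `result["text"]` directly. A minimal sketch of the openai-whisper calls the new code leans on; `example.wav` is a hypothetical path:

import whisper

model = whisper.load_model("large-v2", device="cuda")
audio = whisper.load_audio("example.wav")  # mono float32 array at 16 kHz
audio = whisper.pad_or_trim(audio)         # pad or cut to the 30-second input window
result = model.transcribe(audio)           # language is auto-detected
print(result["text"])

Worth noting: because `pad_or_trim` fixes the input at 30 seconds, longer recordings are truncated, whereas calling `model.transcribe` on the un-trimmed array would process the whole file in 30-second windows.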
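The other structural change fits the "Running on Zero" badge: the global lazy-loading of models is replaced by loading everything inside `infer`, and `@spaces.GPU(gpu_timeout=gpu_timeout)` becomes `@spaces.GPU(duration=gpu_timeout)`. On a ZeroGPU Space the GPU is only attached while the decorated function runs, so the device check and model loading belong inside it. A minimal sketch of the pattern; `load_models` is a hypothetical stand-in for the CSM/Whisper/Gemma loading in app.py:

import spaces
import torch

def load_models(device: str) -> torch.nn.Module:
    # Hypothetical stand-in for the load_csm_1b, whisper.load_model, and
    # Gemma tokenizer/model loading that app.py performs at this point.
    return torch.nn.Identity().to(device)

@spaces.GPU(duration=120)  # seconds of GPU time requested per call
def infer(x: torch.Tensor) -> torch.Tensor:
    # CUDA is only guaranteed to be visible inside the decorated function,
    # so both the device check and the model loading happen here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_models(device)
    return model(x.to(device))

The trade-off visible in the diff is that every call now re-instantiates all three models; the old global cache avoided that, but was presumably initialized without a GPU attached.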
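One detail carried over unchanged: `_infer` appends two segments per turn (user and AI) while the cap pops only one, so `conversation_history` can still grow past `MAX_CONTEXT_SEGMENTS` by one segment per turn. A self-contained sketch of a stricter alternative using `collections.deque`, not what the commit does:

from collections import deque

MAX_CONTEXT_SEGMENTS = 5
history = deque(maxlen=MAX_CONTEXT_SEGMENTS)

for turn in range(4):                   # 4 turns x 2 appends = 8 segments offered
    history.append(f"user-{turn}")      # stand-ins for the Segment objects
    history.append(f"ai-{turn}")

print(len(history))                     # 5: the deque dropped the oldest entries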
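Finally, the new `tuple[int, np.ndarray]` return annotation matches Gradio's convention for `gr.Audio` outputs: a `(sample_rate, samples)` pair. A toy illustration; the 24 kHz rate and sine content are stand-ins, not values read from the model:

import numpy as np

def tts_stub() -> tuple[int, np.ndarray]:
    # gr.Audio accepts (sample_rate, samples); int16 is a common dtype.
    sr = 24_000
    t = np.linspace(0, 1, sr, endpoint=False)
    return sr, (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)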