liuyang committed
Commit 4a29c47 · 1 Parent(s): 37d6160

modify workflow: run diarization first, then transcribe each speaker segment

Files changed (1)
  1. app.py +100 -168
app.py CHANGED
@@ -35,7 +35,7 @@ pipe = pipeline(
     model="openai/whisper-large-v3-turbo",
     torch_dtype=torch.float16,
     device="cuda",
-    model_kwargs={"attn_implementation": "sdpa"},#flash_attention_2
+    model_kwargs={"attn_implementation": "flash_attention_2"},#flash_attention_2
     return_timestamps=True,
 )

@@ -87,20 +87,41 @@ class WhisperTranscriber:
         except subprocess.CalledProcessError as e:
             raise RuntimeError(f"Audio conversion failed: {e}")

-    @spaces.GPU
-    def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
-        """Transcribe audio using Whisper with flash attention"""
-
-        '''
-        #if self.pipe is None:
-        #    self.setup_models()
-
-        if next(self.pipe.model.parameters()).device.type != "cuda":
-            self.pipe.model.to("cuda")
-        '''
-
-        print("Starting transcription...")
+    def cut_audio_segments(self, audio_path, diarization_segments):
+        """Cut audio into segments based on diarization results"""
+        print("Cutting audio into segments...")
+
+        # Load the full audio
+        waveform, sample_rate = torchaudio.load(audio_path)
+
+        audio_segments = []
+        for segment in diarization_segments:
+            start_sample = int(segment["start"] * sample_rate)
+            end_sample = int(segment["end"] * sample_rate)
+
+            # Extract the segment
+            segment_waveform = waveform[:, start_sample:end_sample]
+
+            # Create temporary file for this segment
+            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+            temp_file.close()
+
+            # Save the segment
+            torchaudio.save(temp_file.name, segment_waveform, sample_rate)
+
+            audio_segments.append({
+                "audio_path": temp_file.name,
+                "start": segment["start"],
+                "end": segment["end"],
+                "speaker": segment["speaker"]
+            })
+
+        return audio_segments
+
+    @spaces.GPU
+    def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
+        """Transcribe multiple audio segments"""
+        print(f"Transcribing {len(audio_segments)} audio segments...")
         start_time = time.time()

         # Prepare generation kwargs
@@ -111,47 +132,54 @@ class WhisperTranscriber:
             generate_kwargs["task"] = "translate"
         if prompt:
             generate_kwargs["prompt_ids"] = self.pipe.tokenizer.encode(prompt)

-        # Transcribe with timestamps
-        result = self.pipe(
-            audio_path,
-            return_timestamps=True,
-            generate_kwargs=generate_kwargs,
-            chunk_length_s=30,
-            batch_size=128,
-        )
-        transcription_time = time.time() - start_time
-        print(f"Transcription completed in {transcription_time:.2f} seconds")
-        # Extract segments and detected language
-        segments = []
-        if "chunks" in result:
-            for chunk in result["chunks"]:
-                segment = {
-                    "start": float(chunk["timestamp"][0] or 0),
-                    "end": float(chunk["timestamp"][1] or 0),
-                    "text": chunk["text"].strip(),
-                }
-                segments.append(segment)
-        else:
-            # Fallback for different result format
-            segments = [{
-                "start": 0.0,
-                "end": 0.0,
-                "text": result["text"]
-            }]
-
-        detected_language = getattr(result, 'language', language or 'unknown')
+        results = []
+        for i, segment in enumerate(audio_segments):
+            print(f"Processing segment {i+1}/{len(audio_segments)}")
+
+            # Transcribe this segment
+            result = self.pipe(
+                segment["audio_path"],
+                return_timestamps=True,
+                generate_kwargs=generate_kwargs,
+                chunk_length_s=30,
+                batch_size=128,
+            )
+
+            # Extract text
+            text = result["text"].strip() if "text" in result else ""
+
+            # Create result entry
+            results.append({
+                "start_time": segment["start"],
+                "end_time": segment["end"],
+                "speaker_label": segment["speaker"],
+                "text": text
+            })
+
+        # Clean up temporary files
+        for segment in audio_segments:
+            if os.path.exists(segment["audio_path"]):
+                os.unlink(segment["audio_path"])

         transcription_time = time.time() - start_time
-        print(f"Transcription parse completed in {transcription_time:.2f} seconds")
+        print(f"All segments transcribed in {transcription_time:.2f} seconds")

-        return segments, detected_language
+        return results

+    @spaces.GPU
     def perform_diarization(self, audio_path, num_speakers=None):
         """Perform speaker diarization"""
         if self.diarization_model is None:
-            print("Diarization model not available, assigning single speaker")
-            return [], 1
+            print("Diarization model not available, creating single speaker segment")
+            # Load audio to get duration
+            waveform, sample_rate = torchaudio.load(audio_path)
+            duration = waveform.shape[1] / sample_rate
+            return [{
+                "start": 0.0,
+                "end": duration,
+                "speaker": "SPEAKER_00"
+            }], 1

         print("Starting diarization...")
         start_time = time.time()
@@ -176,7 +204,7 @@ class WhisperTranscriber:
                 "speaker": speaker
             })

-        unique_speakers = {speaker for _, _, speaker in diarization_list}
+        unique_speakers = {speaker for segment in diarize_segments for speaker in [segment["speaker"]]}
         detected_num_speakers = len(unique_speakers)

         diarization_time = time.time() - start_time
@@ -184,129 +212,35 @@ class WhisperTranscriber:

         return diarize_segments, detected_num_speakers

-    def merge_transcription_and_diarization(self, transcription_segments, diarization_segments):
-        """Merge transcription segments with speaker information"""
-        if not diarization_segments:
-            # No diarization available, assign single speaker
-            for segment in transcription_segments:
-                segment["speaker"] = "SPEAKER_00"
-            return transcription_segments
-
-        print("Merging transcription and diarization...")
-        diarize_df = pd.DataFrame(diarization_segments)
-
-        final_segments = []
-        for segment in transcription_segments:
-            # Calculate intersection with diarization segments
-            diarize_df["intersection"] = np.maximum(0,
-                np.minimum(diarize_df["end"], segment["end"]) -
-                np.maximum(diarize_df["start"], segment["start"])
-            )
-
-            # Find speaker with maximum intersection
-            dia_tmp = diarize_df[diarize_df["intersection"] > 0]
-            if len(dia_tmp) > 0:
-                speaker = (
-                    dia_tmp.groupby("speaker")["intersection"]
-                    .sum()
-                    .sort_values(ascending=False)
-                    .index[0]
-                )
-            else:
-                speaker = "SPEAKER_00"
-
-            segment["speaker"] = speaker
-            segment["duration"] = segment["end"] - segment["start"]
-            final_segments.append(segment)
-
-        return final_segments
-
-    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
-        """Group consecutive segments from the same speaker"""
-        if not segments:
-            return segments
-
-        grouped_segments = []
-        current_group = segments[0].copy()
-        sentence_end_pattern = r"[.!?]+\s*$"
-
-        for segment in segments[1:]:
-            time_gap = segment["start"] - current_group["end"]
-            current_duration = current_group["end"] - current_group["start"]
-
-            # Conditions for combining segments
-            can_combine = (
-                segment["speaker"] == current_group["speaker"] and
-                time_gap <= max_gap and
-                current_duration < max_duration and
-                not re.search(sentence_end_pattern, current_group["text"])
-            )
-
-            if can_combine:
-                # Merge segments
-                current_group["end"] = segment["end"]
-                current_group["text"] += " " + segment["text"]
-                current_group["duration"] = current_group["end"] - current_group["start"]
-            else:
-                # Start new group
-                grouped_segments.append(current_group)
-                current_group = segment.copy()
-
-        grouped_segments.append(current_group)
-
-        # Clean up text
-        for segment in grouped_segments:
-            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
-            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
-
-        return grouped_segments
-
     @spaces.GPU
     def process_audio(self, audio_file, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True):
-        """Main processing function"""
+        """Main processing function - diarization first, then transcription"""
         if audio_file is None:
             return {"error": "No audio file provided"}

         try:
-            # Setup models if not already done
-            #self.setup_models()
-
-            # Convert audio format
-            #wav_path = self.convert_audio_format(audio_file)
-
-            try:
-                # Transcribe audio
-                transcription_segments, detected_language = self.transcribe_audio(
-                    audio_file, language, translate, prompt
-                )
-
-                # Perform diarization
-                diarization_segments, detected_num_speakers = self.perform_diarization(
-                    audio_file, num_speakers
-                )
-
-                # Merge transcription and diarization
-                final_segments = self.merge_transcription_and_diarization(
-                    transcription_segments, diarization_segments
-                )
-
-                # Group segments if requested
-                if group_segments:
-                    final_segments = self.group_segments_by_speaker(final_segments)
-
-                return {
-                    "segments": final_segments,
-                    "language": detected_language,
-                    "num_speakers": detected_num_speakers or 1,
-                    "total_segments": len(final_segments)
-                }
-
-            finally:
-                # Clean up temporary file
-                if os.path.exists(audio_file):
-                    os.unlink(audio_file)
+            print("Starting new processing pipeline...")
+
+            # Step 1: Perform diarization first
+            diarization_segments, detected_num_speakers = self.perform_diarization(
+                audio_file, num_speakers
+            )
+
+            # Step 2: Cut audio into segments based on diarization
+            audio_segments = self.cut_audio_segments(audio_file, diarization_segments)
+
+            # Step 3: Transcribe each segment
+            transcription_results = self.transcribe_audio_segments(
+                audio_segments, language, translate, prompt
+            )
+
+            # Step 4: Return in requested format
+            return {
+                "speaker_count": detected_num_speakers,
+                "transcription": transcription_results
+            }

         except Exception as e:
             import traceback
             traceback.print_exc()
@@ -320,21 +254,19 @@ def format_segments_for_display(result):
     if "error" in result:
         return f"❌ Error: {result['error']}"

-    segments = result.get("segments", [])
-    language = result.get("language", "unknown")
-    num_speakers = result.get("num_speakers", 1)
+    speaker_count = result.get("speaker_count", 1)
+    transcription = result.get("transcription", [])

     output = f"🎯 **Detection Results:**\n"
-    output += f"- Language: {language}\n"
-    output += f"- Speakers: {num_speakers}\n"
-    output += f"- Segments: {len(segments)}\n\n"
+    output += f"- Speakers: {speaker_count}\n"
+    output += f"- Segments: {len(transcription)}\n\n"

     output += "📝 **Transcription:**\n\n"

-    for i, segment in enumerate(segments, 1):
-        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
-        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
-        speaker = segment.get("speaker", "SPEAKER_00")
+    for i, segment in enumerate(transcription, 1):
+        start_time = str(datetime.timedelta(seconds=int(segment["start_time"])))
+        end_time = str(datetime.timedelta(seconds=int(segment["end_time"])))
+        speaker = segment.get("speaker_label", "SPEAKER_00")
         text = segment["text"]

         output += f"**{speaker}** ({start_time} → {end_time})\n"