danelkay committed
Commit 491cd72 · 1 Parent(s): ed2df19

Injected gradio-lite via gr.HTML(), added WebGPU support

Files changed (1)
  1. app.py +95 -767
app.py CHANGED
@@ -1,781 +1,109 @@
1
- from datetime import datetime
2
- import json
3
- import math
4
- from typing import Callable, Iterator, Union
5
- import argparse
6
-
7
- from io import StringIO
8
  import os
9
- import pathlib
10
  import tempfile
11
- import zipfile
12
- import numpy as np
13
-
14
- import torch
15
-
16
- from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
- from src.diarization.diarization import Diarization
18
- from src.diarization.diarizationContainer import DiarizationContainer
19
- from src.diarization.transcriptLoader import load_transcript
20
- from src.hooks.progressListener import ProgressListener
21
- from src.hooks.subTaskProgressListener import SubTaskProgressListener
22
- from src.languages import get_language_names
23
- from src.modelCache import ModelCache
24
- from src.prompts.jsonPromptStrategy import JsonPromptStrategy
25
- from src.prompts.prependPromptStrategy import PrependPromptStrategy
26
- from src.source import AudioSource, get_audio_source_collection
27
- from src.vadParallel import ParallelContext, ParallelTranscription
28
-
29
- # External programs
30
  import ffmpeg
31
-
32
- # UI
33
- import gradio as gr
34
-
35
- from src.download import ExceededMaximumDuration, download_url
36
- from src.utils import optional_int, slugify, str2bool, write_srt, write_vtt
37
- from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
38
- from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
39
- from src.whisper.whisperFactory import create_whisper_container
40
-
41
- # Configure more application defaults in config.json5
42
-
43
- # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
44
- MAX_FILE_PREFIX_LENGTH = 17
45
-
46
- # Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
47
- MAX_AUTO_CPU_CORES = 8
48
-
49
- WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
50
-
51
- class VadOptions:
52
- def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
53
- vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
54
- self.vad = vad
55
- self.vadMergeWindow = vadMergeWindow
56
- self.vadMaxMergeSize = vadMaxMergeSize
57
- self.vadPadding = vadPadding
58
- self.vadPromptWindow = vadPromptWindow
59
- self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
60
- else VadInitialPromptMode.from_string(vadInitialPromptMode)
61
-
62
- class WhisperTranscriber:
63
- def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
64
- vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
65
- app_config: ApplicationConfig = None):
66
- self.model_cache = ModelCache()
67
- self.parallel_device_list = None
68
- self.gpu_parallel_context = None
69
- self.cpu_parallel_context = None
70
- self.vad_process_timeout = vad_process_timeout
71
- self.vad_cpu_cores = vad_cpu_cores
72
-
73
- self.vad_model = None
74
- self.inputAudioMaxDuration = input_audio_max_duration
75
- self.deleteUploadedFiles = delete_uploaded_files
76
- self.output_dir = output_dir
77
-
78
- # Support for diarization
79
- self.diarization: DiarizationContainer = None
80
- # Dictionary with parameters to pass to diarization.run - if None, diarization is not enabled
81
- self.diarization_kwargs = None
82
- self.app_config = app_config
83
-
84
- def set_parallel_devices(self, vad_parallel_devices: str):
85
- self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
86
-
87
- def set_auto_parallel(self, auto_parallel: bool):
88
- if auto_parallel:
89
- if torch.cuda.is_available():
90
- self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
91
-
92
- self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
93
- print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
94
-
95
- def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, **kwargs):
96
- if self.diarization is None:
97
- self.diarization = DiarizationContainer(auth_token=auth_token, enable_daemon_process=enable_daemon_process,
98
- auto_cleanup_timeout_seconds=self.app_config.diarization_process_timeout,
99
- cache=self.model_cache)
100
- # Set parameters
101
- self.diarization_kwargs = kwargs
102
-
103
- def unset_diarization(self):
104
- if self.diarization is not None:
105
- self.diarization.cleanup()
106
- self.diarization_kwargs = None
107
-
108
- # Entry function for the simple tab
109
- def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
110
- vad, vadMergeWindow, vadMaxMergeSize,
111
- word_timestamps: bool = False, highlight_words: bool = False,
112
- diarization: bool = False, diarization_speakers: int = 2):
113
- return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
114
- vad, vadMergeWindow, vadMaxMergeSize,
115
- word_timestamps, highlight_words,
116
- diarization, diarization_speakers)
117
-
118
- # Entry function for the simple tab progress
119
- def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
120
- vad, vadMergeWindow, vadMaxMergeSize,
121
- word_timestamps: bool = False, highlight_words: bool = False,
122
- diarization: bool = False, diarization_speakers: int = 2,
123
- progress=gr.Progress()):
124
-
125
- vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
126
-
127
- if diarization:
128
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers)
129
- else:
130
- self.unset_diarization()
131
-
132
- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
133
- word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
134
-
135
- # Entry function for the full tab
136
- def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
137
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
138
- # Word timestamps
139
- word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
140
- initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
141
- condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
142
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
143
- diarization: bool = False, diarization_speakers: int = 2,
144
- diarization_min_speakers = 1, diarization_max_speakers = 5):
145
-
146
- return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
147
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
148
- word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
149
- initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
150
- condition_on_previous_text, fp16, temperature_increment_on_fallback,
151
- compression_ratio_threshold, logprob_threshold, no_speech_threshold,
152
- diarization, diarization_speakers,
153
- diarization_min_speakers, diarization_max_speakers)
154
-
155
- # Entry function for the full tab with progress
156
- def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
157
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
158
- # Word timestamps
159
- word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
160
- initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
161
- condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
162
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
163
- diarization: bool = False, diarization_speakers: int = 2,
164
- diarization_min_speakers = 1, diarization_max_speakers = 5,
165
- progress=gr.Progress()):
166
-
167
- # Handle temperature_increment_on_fallback
168
- if temperature_increment_on_fallback is not None:
169
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
170
- else:
171
- temperature = [temperature]
172
-
173
- vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
174
-
175
- # Set diarization
176
- if diarization:
177
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
178
- min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
179
- else:
180
- self.unset_diarization()
181
-
182
- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
183
- initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
184
- condition_on_previous_text=condition_on_previous_text, fp16=fp16,
185
- compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
186
- word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
187
- progress=progress)
188
-
189
- # Perform diarization given a specific input audio file and whisper file
190
- def perform_extra(self, languageName, urlData, singleFile, whisper_file: str,
191
- highlight_words: bool = False,
192
- diarization: bool = False, diarization_speakers: int = 2, diarization_min_speakers = 1, diarization_max_speakers = 5, progress=gr.Progress()):
193
-
194
- if whisper_file is None:
195
- raise ValueError("whisper_file is required")
196
-
197
- # Set diarization
198
- if diarization:
199
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
200
- min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
201
- else:
202
- self.unset_diarization()
203
-
204
- def custom_transcribe_file(source: AudioSource):
205
- result = load_transcript(whisper_file.name)
206
-
207
- # Set language if not set
208
- if not "language" in result:
209
- result["language"] = languageName
210
-
211
- # Mark speakers
212
- result = self._handle_diarization(source.source_path, result)
213
- return result
214
-
215
- multipleFiles = [singleFile] if singleFile else None
216
-
217
- # Will return download, text, vtt
218
- return self.transcribe_webui("base", "", urlData, multipleFiles, None, None, None,
219
- progress=progress,highlight_words=highlight_words,
220
- override_transcribe_file=custom_transcribe_file, override_max_sources=1)
221
-
222
- def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
223
- vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
224
- override_transcribe_file: Callable[[AudioSource], dict] = None, override_max_sources = None,
225
- **decodeOptions: dict):
226
- try:
227
- sources = self.__get_source(urlData, multipleFiles, microphoneData)
228
-
229
- if override_max_sources is not None and len(sources) > override_max_sources:
230
- raise ValueError("Maximum number of sources is " + str(override_max_sources) + ", but " + str(len(sources)) + " were provided")
231
-
232
- try:
233
- selectedLanguage = languageName.lower() if len(languageName) > 0 else None
234
- selectedModel = modelName if modelName is not None else "base"
235
-
236
- if override_transcribe_file is None:
237
- model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
238
- model_name=selectedModel, compute_type=self.app_config.compute_type,
239
- cache=self.model_cache, models=self.app_config.models)
240
- else:
241
- model = None
242
-
243
- # Result
244
- download = []
245
- zip_file_lookup = {}
246
- text = ""
247
- vtt = ""
248
-
249
- # Write result
250
- downloadDirectory = tempfile.mkdtemp()
251
- source_index = 0
252
-
253
- outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
254
-
255
- # Progress
256
- total_duration = sum([source.get_audio_duration() for source in sources])
257
- current_progress = 0
258
-
259
- # A listener that will report progress to Gradio
260
- root_progress_listener = self._create_progress_listener(progress)
261
-
262
- # Execute whisper
263
- for source in sources:
264
- source_prefix = ""
265
- source_audio_duration = source.get_audio_duration()
266
-
267
- if (len(sources) > 1):
268
- # Prefix (minimum 2 digits)
269
- source_index += 1
270
- source_prefix = str(source_index).zfill(2) + "_"
271
- print("Transcribing ", source.source_path)
272
-
273
- scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
274
- base_task_total=total_duration,
275
- sub_task_start=current_progress,
276
- sub_task_total=source_audio_duration)
277
-
278
- # Transcribe using the override function if specified
279
- if override_transcribe_file is None:
280
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
281
- else:
282
- result = override_transcribe_file(source)
283
-
284
- filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
285
-
286
- # Update progress
287
- current_progress += source_audio_duration
288
-
289
- source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
290
-
291
- if len(sources) > 1:
292
- # Add new line separators
293
- if (len(source_text) > 0):
294
- source_text += os.linesep + os.linesep
295
- if (len(source_vtt) > 0):
296
- source_vtt += os.linesep + os.linesep
297
-
298
- # Append file name to source text too
299
- source_text = source.get_full_name() + ":" + os.linesep + source_text
300
- source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
301
-
302
- # Add to result
303
- download.extend(source_download)
304
- text += source_text
305
- vtt += source_vtt
306
-
307
- if (len(sources) > 1):
308
- # Zip files support at least 260 characters, but we'll play it safe and use 200
309
- zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
310
-
311
- # File names in ZIP file can be longer
312
- for source_download_file in source_download:
313
- # Get file postfix (after last -)
314
- filePostfix = os.path.basename(source_download_file).split("-")[-1]
315
- zip_file_name = zipFilePrefix + "-" + filePostfix
316
- zip_file_lookup[source_download_file] = zip_file_name
317
-
318
- # Create zip file from all sources
319
- if len(sources) > 1:
320
- downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
321
-
322
- with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
323
- for download_file in download:
324
- # Get file name from lookup
325
- zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
326
- zip.write(download_file, arcname=zip_file_name)
327
-
328
- download.insert(0, downloadAllPath)
329
-
330
- return download, text, vtt
331
-
332
- finally:
333
- # Cleanup source
334
- if self.deleteUploadedFiles:
335
- for source in sources:
336
- print("Deleting source file " + source.source_path)
337
-
338
- try:
339
- os.remove(source.source_path)
340
- except Exception as e:
341
- # Ignore error - it's just a cleanup
342
- print("Error deleting source file " + source.source_path + ": " + str(e))
343
-
344
- except ExceededMaximumDuration as e:
345
- return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
346
-
347
- def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
348
- vadOptions: VadOptions = VadOptions(),
349
- progressListener: ProgressListener = None, **decodeOptions: dict):
350
-
351
- initial_prompt = decodeOptions.pop('initial_prompt', None)
352
-
353
- if progressListener is None:
354
- # Default progress listener
355
- progressListener = ProgressListener()
356
-
357
- if ('task' in decodeOptions):
358
- task = decodeOptions.pop('task')
359
-
360
- initial_prompt_mode = vadOptions.vadInitialPromptMode
361
-
362
- # Set default initial prompt mode
363
- if (initial_prompt_mode is None):
364
- initial_prompt_mode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT
365
-
366
- if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
367
- initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
368
- # Prepend initial prompt
369
- prompt_strategy = PrependPromptStrategy(initial_prompt, initial_prompt_mode)
370
- elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
371
- # Use a JSON format to specify the prompt for each segment
372
- prompt_strategy = JsonPromptStrategy(initial_prompt)
373
- else:
374
- raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
375
-
376
- # Callable for processing an audio file
377
- whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
378
-
379
- # The results
380
- if (vadOptions.vad == 'silero-vad'):
381
- # Silero VAD where non-speech gaps are transcribed
382
- process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
383
- result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
384
- elif (vadOptions.vad == 'silero-vad-skip-gaps'):
385
- # Silero VAD where non-speech gaps are simply ignored
386
- skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
387
- result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
388
- elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
389
- # Use Silero VAD where speech-segments are expanded into non-speech gaps
390
- expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
391
- result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
392
- elif (vadOptions.vad == 'periodic-vad'):
393
- # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
394
- # it may create a break in the middle of a sentence, causing some artifacts.
395
- periodic_vad = VadPeriodicTranscription()
396
- period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
397
- result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
398
-
399
- else:
400
- if (self._has_parallel_devices()):
401
- # Use a simple period transcription instead, as we need to use the parallel context
402
- periodic_vad = VadPeriodicTranscription()
403
- period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
404
-
405
- result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
406
- else:
407
- # Default VAD
408
- result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
409
-
410
- # Diarization
411
- result = self._handle_diarization(audio_path, result)
412
- return result
413
-
414
- def _handle_diarization(self, audio_path: str, input: dict):
415
- if self.diarization and self.diarization_kwargs:
416
- print("Diarizing ", audio_path)
417
- diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
418
-
419
- # Print result
420
- print("Diarization result: ")
421
- for entry in diarization_result:
422
- print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
423
-
424
- # Add speakers to result
425
- input = self.diarization.mark_speakers(diarization_result, input)
426
-
427
- return input
428
-
429
- def _create_progress_listener(self, progress: gr.Progress):
430
- if (progress is None):
431
- # Dummy progress listener
432
- return ProgressListener()
433
-
434
- class ForwardingProgressListener(ProgressListener):
435
- def __init__(self, progress: gr.Progress):
436
- self.progress = progress
437
-
438
- def on_progress(self, current: Union[int, float], total: Union[int, float]):
439
- # From 0 to 1
440
- self.progress(current / total)
441
-
442
- def on_finished(self):
443
- self.progress(1)
444
-
445
- return ForwardingProgressListener(progress)
446
-
447
- def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
448
- progressListener: ProgressListener = None):
449
- if (not self._has_parallel_devices()):
450
- # No parallel devices, so just run the VAD and Whisper in sequence
451
- return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
452
-
453
- gpu_devices = self.parallel_device_list
454
-
455
- if (gpu_devices is None or len(gpu_devices) == 0):
456
- # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
457
- gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
458
-
459
- # Create parallel context if needed
460
- if (self.gpu_parallel_context is None):
461
- # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
462
- self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
463
- # We also need a CPU context for the VAD
464
- if (self.cpu_parallel_context is None):
465
- self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
466
-
467
- parallel_vad = ParallelTranscription()
468
- return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
469
- config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
470
- cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
471
- progress_listener=progressListener)
472
-
473
- def _has_parallel_devices(self):
474
- return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
475
-
476
- def _concat_prompt(self, prompt1, prompt2):
477
- if (prompt1 is None):
478
- return prompt2
479
- elif (prompt2 is None):
480
- return prompt1
481
- else:
482
- return prompt1 + " " + prompt2
483
-
484
- def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
485
- # Use Silero VAD
486
- if (self.vad_model is None):
487
- self.vad_model = VadSileroTranscription()
488
-
489
- config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
490
- max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
491
- segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
492
- max_prompt_window=vadOptions.vadPromptWindow)
493
-
494
- return config
495
-
496
- def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
497
- if not os.path.exists(output_dir):
498
- os.makedirs(output_dir)
499
-
500
- text = result["text"]
501
- language = result["language"] if "language" in result else None
502
- languageMaxLineWidth = self.__get_max_line_width(language)
503
-
504
- # We always create the JSON file for debugging purposes
505
- json_result = json.dumps(result, indent=4, ensure_ascii=False)
506
- json_file = self.__create_file(json_result, output_dir, source_name + "-result.json")
507
- print("Created JSON file " + json_file)
508
-
509
- print("Max line width " + str(languageMaxLineWidth))
510
- vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
511
- srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
512
-
513
- output_files = []
514
- output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
515
- output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
516
- output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
517
- output_files.append(json_file)
518
-
519
- return output_files, text, vtt
520
-
521
- def clear_cache(self):
522
- self.model_cache.clear()
523
- self.vad_model = None
524
-
525
- def __get_source(self, urlData, multipleFiles, microphoneData):
526
- return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
527
-
528
- def __get_max_line_width(self, language: str) -> int:
529
- if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
530
- # Chinese characters and kana are wider, so limit line length to 40 characters
531
- return 40
532
- else:
533
- # TODO: Add more languages
534
- # 80 latin characters should fit on a 1080p/720p screen
535
- return 80
536
-
537
- def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
538
- segmentStream = StringIO()
539
-
540
- if format == 'vtt':
541
- write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
542
- elif format == 'srt':
543
- write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
544
- else:
545
- raise Exception("Unknown format " + format)
546
-
547
- segmentStream.seek(0)
548
- return segmentStream.read()
549
-
550
- def __create_file(self, text: str, directory: str, fileName: str) -> str:
551
- # Write the text to a file
552
- with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
553
- file.write(text)
554
-
555
- return file.name
556
-
557
- def close(self):
558
- print("Closing parallel contexts")
559
- self.clear_cache()
560
-
561
- if (self.gpu_parallel_context is not None):
562
- self.gpu_parallel_context.close()
563
- if (self.cpu_parallel_context is not None):
564
- self.cpu_parallel_context.close()
565
-
566
- # Cleanup diarization
567
- if (self.diarization is not None):
568
- self.diarization.cleanup()
569
- self.diarization = None
570
-
571
- def create_ui(app_config: ApplicationConfig):
572
- ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
573
- app_config.delete_uploaded_files, app_config.output_dir, app_config)
574
-
575
- # Specify a list of devices to use for parallel processing
576
- ui.set_parallel_devices(app_config.vad_parallel_devices)
577
- ui.set_auto_parallel(app_config.auto_parallel)
578
-
579
- is_whisper = False
580
-
581
- if app_config.whisper_implementation == "whisper":
582
- implementation_name = "Whisper"
583
- is_whisper = True
584
- elif app_config.whisper_implementation in ["faster-whisper", "faster_whisper"]:
585
- implementation_name = "Faster Whisper"
586
- else:
587
- # Try to convert from camel-case to title-case
588
- implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
589
-
590
- ui_description = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
591
- ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
592
- ui_description += " as well as speech translation and language identification. "
593
-
594
- ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
595
-
596
- # Recommend faster-whisper
597
- if is_whisper:
598
- ui_description += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
599
-
600
- if app_config.input_audio_max_duration > 0:
601
- ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
602
-
603
- ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
604
-
605
- whisper_models = app_config.get_model_names()
606
-
607
- common_inputs = lambda : [
608
- gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
609
- gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
610
- gr.Text(label="URL (YouTube, etc.)"),
611
- gr.File(label="Upload Files", file_count="multiple"),
612
- gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
613
- gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
614
- ]
615
-
616
- common_vad_inputs = lambda : [
617
- gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
618
- gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
619
- gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
620
- ]
621
 
622
- common_word_timestamps_inputs = lambda : [
623
- gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
624
- gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
625
- ]
626
-
627
- has_diarization_libs = Diarization.has_libraries()
628
-
629
- if not has_diarization_libs:
630
- print("Diarization libraries not found - disabling diarization")
631
- app_config.diarization = False
632
-
633
- common_diarization_inputs = lambda : [
634
- gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
635
- gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs)
636
- ]
637
-
638
- is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
639
-
640
- simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
641
- description=ui_description, article=ui_article, inputs=[
642
- *common_inputs(),
643
- *common_vad_inputs(),
644
- *common_word_timestamps_inputs(),
645
- *common_diarization_inputs(),
646
- ], outputs=[
647
- gr.File(label="Download"),
648
- gr.Text(label="Transcription"),
649
- gr.Text(label="Segments")
650
- ])
651
-
652
- full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
653
-
654
- full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
655
- description=full_description, article=ui_article, inputs=[
656
- *common_inputs(),
657
 
658
- *common_vad_inputs(),
659
- gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
660
- gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
661
- gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
662
-
663
- *common_word_timestamps_inputs(),
664
- gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
665
- gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
666
 
667
- gr.TextArea(label="Initial Prompt"),
668
- gr.Number(label="Temperature", value=app_config.temperature),
669
- gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
670
- gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
671
- gr.Number(label="Patience - Zero temperature", value=app_config.patience),
672
- gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
673
- gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
674
- gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
675
- gr.Checkbox(label="FP16", value=app_config.fp16),
676
- gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
677
- gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
678
- gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
679
- gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
680
 
681
- *common_diarization_inputs(),
682
- gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
683
- gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
684
 
685
- ], outputs=[
686
- gr.File(label="Download"),
687
- gr.Text(label="Transcription"),
688
- gr.Text(label="Segments")
689
- ])
690
 
691
- perform_extra_interface = gr.Interface(fn=ui.perform_extra,
692
- description="Perform additional processing on a given JSON or SRT file", article=ui_article, inputs=[
693
- gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
694
- gr.Text(label="URL (YouTube, etc.)"),
695
- gr.File(label="Upload Audio File", file_count="single"),
696
- gr.File(label="Upload JSON/SRT File", file_count="single"),
697
- gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
698
 
699
- *common_diarization_inputs(),
700
- gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
701
- gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
702
-
703
- ], outputs=[
704
- gr.File(label="Download"),
705
- gr.Text(label="Transcription"),
706
- gr.Text(label="Segments")
707
- ])
708
-
709
- demo = gr.TabbedInterface([simple_transcribe, full_transcribe, perform_extra_interface], tab_names=["Simple", "Full", "Extra"])
710
-
711
- # Queue up the demo
712
- if is_queue_mode:
713
- demo.queue(concurrency_count=app_config.queue_concurrency_count)
714
- print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
715
- else:
716
- print("Queue mode disabled - progress bars will not be shown.")
717
-
718
- demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
719
-
720
- # Clean up
721
- ui.close()
722
-
723
- if __name__ == '__main__':
724
- default_app_config = ApplicationConfig.create_default()
725
- whisper_models = default_app_config.get_model_names()
726
-
727
- # Environment variable overrides
728
- default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
729
-
730
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
731
- parser.add_argument("--input_audio_max_duration", type=int, default=default_app_config.input_audio_max_duration, \
732
- help="Maximum audio file length in seconds, or -1 for no limit.") # 600
733
- parser.add_argument("--share", type=bool, default=default_app_config.share, \
734
- help="True to share the app on HuggingFace.") # False
735
- parser.add_argument("--server_name", type=str, default=default_app_config.server_name, \
736
- help="The host or IP to bind to. If None, bind to localhost.") # None
737
- parser.add_argument("--server_port", type=int, default=default_app_config.server_port, \
738
- help="The port to bind to.") # 7860
739
- parser.add_argument("--queue_concurrency_count", type=int, default=default_app_config.queue_concurrency_count, \
740
- help="The number of concurrent requests to process.") # 1
741
- parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=default_app_config.default_model_name, \
742
- help="The default model name.") # medium
743
- parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
744
- help="The default VAD.") # silero-vad
745
- parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
746
- help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
747
- parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
748
- help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
749
- parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
750
- help="The number of CPU cores to use for VAD pre-processing.") # 1
751
- parser.add_argument("--vad_process_timeout", type=float, default=default_app_config.vad_process_timeout, \
752
- help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
753
- parser.add_argument("--auto_parallel", type=bool, default=default_app_config.auto_parallel, \
754
- help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
755
- parser.add_argument("--output_dir", "-o", type=str, default=default_app_config.output_dir, \
756
- help="directory to save the outputs")
757
- parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
758
- help="the Whisper implementation to use")
759
- parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
760
- help="the compute type to use for inference")
761
- parser.add_argument("--threads", type=optional_int, default=0,
762
- help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
763
-
764
- parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
765
- parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
766
- help="whether to perform speaker diarization")
767
- parser.add_argument("--diarization_num_speakers", type=int, default=default_app_config.diarization_speakers, help="Number of speakers")
768
- parser.add_argument("--diarization_min_speakers", type=int, default=default_app_config.diarization_min_speakers, help="Minimum number of speakers")
769
- parser.add_argument("--diarization_max_speakers", type=int, default=default_app_config.diarization_max_speakers, help="Maximum number of speakers")
770
- parser.add_argument("--diarization_process_timeout", type=int, default=default_app_config.diarization_process_timeout, \
771
- help="Number of seconds before inactivate diarization processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
772
 
773
- args = parser.parse_args().__dict__
 
774
 
775
- updated_config = default_app_config.update(**args)
 
 
776
 
777
- if (threads := args.pop("threads")) > 0:
778
- torch.set_num_threads(threads)
779
 
780
- print("Using whisper implementation: " + updated_config.whisper_implementation)
781
- create_ui(app_config=updated_config)
 
+ import gradio as gr
  import os
  import tempfile
  import ffmpeg
+ import json
+ from huggingface_hub import InferenceApi
+ from typing import List, Dict, Tuple
+
+ # 🔹 Constants
+ MODEL_NAME: str = "ivrit-ai/faster-whisper-v2-d4"
+ TRANSLATION_MODEL_NAME: str = "dicta-il/dictalm2.0-GGUF"
+ TEMP_DIR: str = tempfile.gettempdir()
+
+ # 🔹 Load Hugging Face Inference API
+ ASR_API = InferenceApi(repo_id=MODEL_NAME)
+ TRANSLATION_API = InferenceApi(repo_id=TRANSLATION_MODEL_NAME)
+
+ def convert_audio(audio_path: str) -> str:
+     """Converts an audio file to 16kHz WAV format for compatibility."""
+     converted_path = os.path.join(TEMP_DIR, "converted.wav")
+     (
+         ffmpeg
+         .input(audio_path)
+         .output(converted_path, format="wav", ar="16000")
+         .run(overwrite_output=True, quiet=True)
+     )
+     return converted_path
+
+ def transcribe_audio(file: str, translate: bool) -> Tuple[str, str]:
+     """Transcribes audio and optionally translates it using Hugging Face API."""
+     audio_path = file if file.endswith(".wav") else convert_audio(file)
+
+     with open(audio_path, "rb") as audio_file:
+         result = ASR_API(inputs=audio_file)
+
+     segments = result.get("segments", [])
+     subtitles: List[Dict[str, str]] = []
+     transcribed_text: str = ""
+
+     for segment in segments:
+         hebrew_text = segment["text"]
+         start_time = segment["start"]
+         end_time = segment["end"]
+         eng_translation = ""
+
+         if translate:
+             eng_translation = TRANSLATION_API(inputs=hebrew_text)[0]["translation_text"]
+
+         subtitles.append({
+             "start": start_time,
+             "end": end_time,
+             "text": hebrew_text,
+             "translation": eng_translation if translate else None
+         })
+
+         transcribed_text += f"{hebrew_text} "
+
+     return json.dumps(subtitles), transcribed_text
+
+ # 🔹 Inject WebGPU-compatible JavaScript via `gr.HTML()`
+ webgpu_script = """
+ <script type="module">
+     import { pipeline } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@latest';
+
+     let asr;
+
+     async function loadModel() {
+         asr = await pipeline("automatic-speech-recognition", "openai/whisper-large-v3");
+         console.log("WebGPU ASR model loaded.");
+     }
+
+     async function transcribe(audioFile) {
+         if (!asr) {
+             console.error("Model not loaded.");
+             return;
+         }
+         const result = await asr(audioFile);
+         document.getElementById("output").innerText = result.text;
+     }
+
+     document.getElementById("upload").addEventListener("change", async (event) => {
+         const file = event.target.files[0];
+         transcribe(file);
+     });
+
+     loadModel();
+ </script>
+
+ <input type="file" id="upload" accept="audio/*">
+ <p id="output">Transcription will appear here.</p>
+ """
+
+ # 🔹 Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# WhatShutup: Transcribe WhatsApp Voice Messages with WebGPU Support")
+
+     webgpu_component = gr.HTML(webgpu_script)
+
+     audio_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
+     translate_checkbox = gr.Checkbox(label="Translate to English?", value=False)
+
+     with gr.Row():
+         audio_player = gr.Audio(source="upload", type="filepath", label="Playback")
+         transcript_output = gr.Textbox(label="Transcription & Subtitles", lines=10)
+
+     submit_btn = gr.Button("Transcribe")
+     submit_btn.click(transcribe_audio, inputs=[audio_input, translate_checkbox], outputs=[audio_player, transcript_output])
+
+ demo.launch()
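
For reference, a minimal sketch (not part of this commit) of consuming the subtitles JSON that the new `transcribe_audio` returns as its first value. The `to_srt` helper and the sample segments below are illustrative assumptions; they only mirror the dict shape built in the loop above (`start`, `end`, `text`, `translation`).

import json

def to_srt(subtitles: list) -> str:
    """Illustrative helper: render subtitle dicts in transcribe_audio's format as SRT text."""
    def fmt(seconds: float) -> str:
        # SRT timestamps use HH:MM:SS,mmm
        hours, rem = divmod(int(seconds), 3600)
        minutes, secs = divmod(rem, 60)
        millis = int((seconds - int(seconds)) * 1000)
        return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

    blocks = []
    for index, sub in enumerate(subtitles, start=1):
        text = sub["text"]
        if sub.get("translation"):
            # Append the optional English translation on a second line
            text += "\n" + sub["translation"]
        blocks.append(f"{index}\n{fmt(sub['start'])} --> {fmt(sub['end'])}\n{text}\n")
    return "\n".join(blocks)

# Illustrative sample in the same shape as transcribe_audio's output (hypothetical data)
sample_json = json.dumps([
    {"start": 0.0, "end": 2.5, "text": "שלום עולם", "translation": None},
    {"start": 2.5, "end": 5.0, "text": "מה נשמע", "translation": "How are you"},
])
print(to_srt(json.loads(sample_json)))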