Add configuration file and support for custom models

Custom models can be added to the configuration file, under the
"models" section. See the comments in config.json5 for more details.
- .gitignore +1 -0
- app.py +37 -17
- cli.py +72 -38
- config.json5 +62 -0
- requirements.txt +3 -1
- src/config.py +134 -0
- src/conversion/hf_converter.py +67 -0
- src/whisperContainer.py +29 -3
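
Taken together, the changes make both entry points read their defaults from one file. A minimal sketch of the new flow (assuming the modules from this diff are importable):

    import os

    from src.config import ApplicationConfig

    # Both app.py and cli.py load the config like this: config.json5 by default,
    # overridable through the WHISPER_WEBUI_CONFIG environment variable.
    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

    # Names from the "models" section become the model choices in the UI and CLI.
    print(app_config.get_model_names())  # ['tiny', 'base', 'small', 'medium', 'large', 'large-v2']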
    	
.gitignore
CHANGED

@@ -1,5 +1,6 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
+.vscode/
 flagged/
 *.py[cod]
 *$py.class
    	
app.py
CHANGED

@@ -11,6 +11,7 @@ import zipfile
 import numpy as np
 
 import torch
+from src.config import ApplicationConfig
 from src.modelCache import ModelCache
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
@@ -62,7 +63,8 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large
 
 class WhisperTranscriber:
     def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, 
-                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None):
+                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None, 
+                 app_config: ApplicationConfig = None):
         self.model_cache = ModelCache()
         self.parallel_device_list = None
         self.gpu_parallel_context = None
@@ -75,6 +77,8 @@ class WhisperTranscriber:
         self.deleteUploadedFiles = delete_uploaded_files
         self.output_dir = output_dir
 
+        self.app_config = app_config
+
     def set_parallel_devices(self, vad_parallel_devices: str):
         self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
 
@@ -115,7 +119,7 @@ class WhisperTranscriber:
                 selectedLanguage = languageName.lower() if len(languageName) > 0 else None
                 selectedModel = modelName if modelName is not None else "base"
 
-                model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
+                model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
 
                 # Result
                 download = []
@@ -360,8 +364,8 @@ class WhisperTranscriber:
 def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860, 
               default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, 
               vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False, 
-              output_dir: str = None):
-    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir)
+              output_dir: str = None, app_config: ApplicationConfig = None):
+    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir, app_config)
 
     # Specify a list of devices to use for parallel processing
     ui.set_parallel_devices(vad_parallel_devices)
@@ -378,8 +382,10 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
 
     ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
 
+    whisper_models = app_config.get_model_names()
+
     simple_inputs = lambda : [
-        gr.Dropdown(choices=WHISPER_MODELS, value=default_model_name, label="Model"),
+        gr.Dropdown(choices=whisper_models, value=default_model_name, label="Model"),
         gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         gr.Text(label="URL (YouTube, etc.)"),
         gr.File(label="Upload Files", file_count="multiple"),
@@ -429,18 +435,32 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
     ui.close()
 
 if __name__ == '__main__':
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("--input_audio_max_duration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
-    parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
-    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
-    parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
-    parser.add_argument("--default_model_name", type=str, choices=WHISPER_MODELS, default="medium", help="The default model name.")
-    parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
-    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
-    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
-    parser.add_argument("--vad_process_timeout", type=float, default=1800, help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
-    parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
-    parser.add_argument("--output_dir", "-o", type=str, default=None, help="directory to save the outputs")
+    parser.add_argument("--input_audio_max_duration", type=int, default=app_config.input_audio_max_duration, \
+                        help="Maximum audio file length in seconds, or -1 for no limit.") # 600
+    parser.add_argument("--share", type=bool, default=app_config.share, \
+                        help="True to share the app on HuggingFace.") # False
+    parser.add_argument("--server_name", type=str, default=app_config.server_name, \
+                        help="The host or IP to bind to. If None, bind to localhost.") # None
+    parser.add_argument("--server_port", type=int, default=app_config.server_port, \
+                        help="The port to bind to.") # 7860
+    parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=app_config.default_model_name, \
+                        help="The default model name.") # medium
+    parser.add_argument("--default_vad", type=str, default=app_config.default_vad, \
+                        help="The default VAD.") # silero-vad
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_process_timeout", type=float, default=app_config.vad_process_timeout, \
+                        help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+    parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
+                        help="directory to save the outputs") # None
 
     args = parser.parse_args().__dict__
-    create_ui(**args)
+    create_ui(app_config=app_config, **args)
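
Note the pattern above: config.json5 supplies every argparse default, so an explicit command-line flag still wins. A small illustration (the port values are hypothetical, not part of the commit):

    import argparse

    from src.config import ApplicationConfig

    app_config = ApplicationConfig.parse_file("config.json5")

    parser = argparse.ArgumentParser()
    parser.add_argument("--server_port", type=int, default=app_config.server_port)

    # Without the flag, the config value (7860) is used; the flag overrides it.
    args = parser.parse_args(["--server_port", "8000"])
    print(args.server_port)  # 8000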
    	
cli.py
CHANGED

@@ -6,48 +6,81 @@ import warnings
 import numpy as np
 
 import torch
-from app import LANGUAGES, WHISPER_MODELS, WhisperTranscriber
+from app import LANGUAGES, WhisperTranscriber
+from src.config import ApplicationConfig
 from src.download import download_url
 
 from src.utils import optional_float, optional_int, str2bool
 from src.whisperContainer import WhisperContainer
 
 def cli():
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
-    parser.add_argument("--model", default="medium", choices=WHISPER_MODELS, help="name of the Whisper model to use")
-    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
-
-    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
-
-    parser.add_argument("--vad", type=str, default="silero-vad", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
-    parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
-    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
-    parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
-    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
-    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
-    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
-    parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
-
-    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
-    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
-    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
-
-    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
-    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
-    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
-    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
-
-    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
-    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
-    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("audio", nargs="+", type=str, \
+                        help="audio file(s) to transcribe")
+    parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
+                        help="name of the Whisper model to use") # medium
+    parser.add_argument("--model_dir", type=str, default=None, \
+                        help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", \
+                        help="device to use for PyTorch inference")
+    parser.add_argument("--output_dir", "-o", type=str, default=".", \
+                        help="directory to save the outputs")
+    parser.add_argument("--verbose", type=str2bool, default=True, \
+                        help="whether to print out the progress and debug messages")
+
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], \
+                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), \
+                        help="language spoken in the audio, specify None to perform language detection")
+
+    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
+                        help="The voice activity detection algorithm to use") # silero-vad
+    parser.add_argument("--vad_merge_window", type=optional_float, default=5, \
+                        help="The window size (in seconds) to merge voice segments")
+    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30,\
+                         help="The maximum size (in seconds) of a voice segment")
+    parser.add_argument("--vad_padding", type=optional_float, default=1, \
+                        help="The padding (in seconds) to add to each voice segment")
+    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, \
+                        help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+
+    parser.add_argument("--temperature", type=float, default=0, \
+                        help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, \
+                        help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, \
+                        help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=None, \
+                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=None, \
+                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
+
+    parser.add_argument("--suppress_tokens", type=str, default="-1", \
+                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--initial_prompt", type=str, default=None, \
+                        help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, \
+                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, \
+                        help="whether to perform inference in fp16; True by default")
+
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, \
+                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, \
+                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, \
+                        help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, \
+                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -74,12 +107,13 @@ def cli():
     vad_prompt_window = args.pop("vad_prompt_window")
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
-
-    model = WhisperContainer(model_name, device=device, download_root=model_dir)
-    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
+
+    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
+    model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
+
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
 
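
The CLI now shares the web UI's configuration, including the model list used for the --model choices. A hedged usage sketch (the file names are examples, not from the commit):

    import os
    import sys

    # Point the CLI at an alternative config file (the default is config.json5).
    os.environ["WHISPER_WEBUI_CONFIG"] = "my-config.json5"

    # cli() parses sys.argv, so arguments can be staged before the call; this mirrors:
    #   WHISPER_WEBUI_CONFIG=my-config.json5 python cli.py audio.wav --model tiny
    sys.argv = ["cli.py", "audio.wav", "--model", "tiny"]

    from cli import cli
    cli()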
    	
config.json5
ADDED

@@ -0,0 +1,62 @@
+{
+    "models": [
+        // Configuration for the built-in models. You can remove any of these 
+        // if you don't want to use the default models.
+        {
+            "name": "tiny",
+            "url": "tiny" 
+        },
+        {
+            "name": "base",
+            "url": "base"
+        },
+        {
+            "name": "small",
+            "url": "small"
+        },
+        {
+            "name": "medium",
+            "url": "medium"
+        },
+        {
+            "name": "large",
+            "url": "large"
+        },
+        {
+            "name": "large-v2",
+            "url": "large-v2"
+        },
+        // Uncomment to add custom Japanese models
+        //{
+        //    "name": "whisper-large-v2-mix-jp",
+        //    "url": "vumichien/whisper-large-v2-mix-jp",
+        //    // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
+        //    // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
+        //    "type": "huggingface",
+        //}
+    ],
+    // Configuration options that will be used if they are not specified in the command line arguments.
+
+    // Maximum audio file length in seconds, or -1 for no limit.
+    "input_audio_max_duration": 600,
+    // True to share the app on HuggingFace.
+    "share": false,
+    // The host or IP to bind to. If None, bind to localhost.
+    "server_name": null,
+    // The port to bind to.
+    "server_port": 7860,
+    // The default model name.
+    "default_model_name": "medium",
+    // The default VAD.
+    "default_vad": "silero-vad",
+    // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
+    "vad_parallel_devices": "",
+    // The number of CPU cores to use for VAD pre-processing.
+    "vad_cpu_cores": 1,
+    // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
+    "vad_process_timeout": 1800,
+    // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
+    "auto_parallel": false,
+    // Directory to save the outputs
+    "output_dir": null
+}
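
The commented-out block above is the template for custom models. The same entry, expressed with the new ModelConfig class from src/config.py (a sketch using the values from the comment):

    from src.config import ModelConfig

    # Equivalent of uncommenting the custom-model block in config.json5.
    custom_model = ModelConfig(
        name="whisper-large-v2-mix-jp",
        url="vumichien/whisper-large-v2-mix-jp",  # a HuggingFace repo id
        type="huggingface",                       # converted to Whisper format on first use
    )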
    	
requirements.txt
CHANGED

@@ -1,7 +1,9 @@
+git+https://github.com/huggingface/transformers
 git+https://github.com/openai/whisper.git
 transformers
 ffmpeg-python==0.2.0
 gradio==3.13.0
 yt-dlp
 torchaudio
-altair
\ No newline at end of file
+altair
+json5
    	
src/config.py
ADDED

@@ -0,0 +1,134 @@
+import urllib
+
+import os
+from typing import List
+from urllib.parse import urlparse
+
+from tqdm import tqdm
+
+from src.conversion.hf_converter import convert_hf_whisper
+
+class ModelConfig:
+    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
+        """
+        Initialize a model configuration.
+
+        name: Name of the model
+        url: URL to download the model from
+        path: Path to the model file. If not set, the model will be downloaded from the URL.
+        type: Type of model. Can be whisper or huggingface.
+        """
+        self.name = name
+        self.url = url
+        self.path = path
+        self.type = type
+
+    def download_url(self, root_dir: str):
+        import whisper
+
+        # See if path is already set
+        if self.path is not None:
+            return self.path
+        
+        if root_dir is None:
+            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+        model_type = self.type.lower() if self.type is not None else "whisper"
+
+        if model_type in ["huggingface", "hf"]:
+            self.path = self.url
+            destination_target = os.path.join(root_dir, self.name + ".pt")
+
+            # Convert from HuggingFace format to Whisper format
+            if os.path.exists(destination_target):
+                print(f"File {destination_target} already exists, skipping conversion")
+            else:
+                print("Saving HuggingFace model in Whisper format to " + destination_target)
+                convert_hf_whisper(self.url, destination_target)
+
+            self.path = destination_target
+
+        elif model_type in ["whisper", "w"]:
+            self.path = self.url
+
+            # See if URL is just a file
+            if self.url in whisper._MODELS:
+                # No need to download anything - Whisper will handle it
+                self.path = self.url
+            elif self.url.startswith("file://"):
+                # Get file path
+                self.path = urlparse(self.url).path
+            # See if it is an URL
+            elif self.url.startswith("http://") or self.url.startswith("https://"):
+                # Extension (or file name)
+                extension = os.path.splitext(self.url)[-1]
+                download_target = os.path.join(root_dir, self.name + extension)
+
+                if os.path.exists(download_target) and not os.path.isfile(download_target):
+                    raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+                if not os.path.isfile(download_target):
+                    self._download_file(self.url, download_target)
+                else:
+                    print(f"File {download_target} already exists, skipping download")
+
+                self.path = download_target
+            # Must be a local file
+            else:
+                self.path = self.url
+
+        else:
+            raise ValueError(f"Unknown model type {model_type}")
+
+        return self.path
+
+    def _download_file(self, url: str, destination: str):
+        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+class ApplicationConfig:
+    def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600, 
+                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium", 
+                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800, 
+                 auto_parallel: bool = False, output_dir: str = None):
+        self.models = models
+        self.input_audio_max_duration = input_audio_max_duration
+        self.share = share
+        self.server_name = server_name
+        self.server_port = server_port
+        self.default_model_name = default_model_name
+        self.default_vad = default_vad
+        self.vad_parallel_devices = vad_parallel_devices
+        self.vad_cpu_cores = vad_cpu_cores
+        self.vad_process_timeout = vad_process_timeout
+        self.auto_parallel = auto_parallel
+        self.output_dir = output_dir
+
+    def get_model_names(self):
+        return [ x.name for x in self.models ]
+
+    @staticmethod
+    def parse_file(config_path: str):
+        import json5
+
+        with open(config_path, "r") as f:
+            # Load using json5
+            data = json5.load(f)
+            data_models = data.pop("models", [])
+
+            models = [ ModelConfig(**x) for x in data_models ]
+
+            return ApplicationConfig(models, **data)
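
How ModelConfig.download_url resolves the different "url" forms, following the branches above (a sketch; it assumes whisper is installed, and the file path is made up):

    from src.config import ModelConfig

    # A built-in name is returned unchanged - Whisper downloads it itself.
    print(ModelConfig("tiny", "tiny").download_url(root_dir=None))  # "tiny"

    # A file:// URL is reduced to its local path.
    print(ModelConfig("custom", "file:///models/custom.pt").download_url(root_dir=None))  # "/models/custom.pt"

    # An http(s):// URL is downloaded into root_dir (~/.cache/whisper by default),
    # and a "huggingface" type is converted to a <name>.pt file there instead.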
    	
src/conversion/hf_converter.py
ADDED

@@ -0,0 +1,67 @@
+# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
+
+from copy import deepcopy
+import torch
+from transformers import WhisperForConditionalGeneration
+
+WHISPER_MAPPING = {
+    "layers": "blocks",
+    "fc1": "mlp.0",
+    "fc2": "mlp.2",
+    "final_layer_norm": "mlp_ln",
+    "layers": "blocks",
+    ".self_attn.q_proj": ".attn.query",
+    ".self_attn.k_proj": ".attn.key",
+    ".self_attn.v_proj": ".attn.value",
+    ".self_attn_layer_norm": ".attn_ln",
+    ".self_attn.out_proj": ".attn.out",
+    ".encoder_attn.q_proj": ".cross_attn.query",
+    ".encoder_attn.k_proj": ".cross_attn.key",
+    ".encoder_attn.v_proj": ".cross_attn.value",
+    ".encoder_attn_layer_norm": ".cross_attn_ln",
+    ".encoder_attn.out_proj": ".cross_attn.out",
+    "decoder.layer_norm.": "decoder.ln.",
+    "encoder.layer_norm.": "encoder.ln_post.",
+    "embed_tokens": "token_embedding",
+    "encoder.embed_positions.weight": "encoder.positional_embedding",
+    "decoder.embed_positions.weight": "decoder.positional_embedding",
+    "layer_norm": "ln_post",
+}
+
+
+def rename_keys(s_dict):
+    keys = list(s_dict.keys())
+    for key in keys:
+        new_key = key
+        for k, v in WHISPER_MAPPING.items():
+            if k in key:
+                new_key = new_key.replace(k, v)
+
+        print(f"{key} -> {new_key}")
+
+        s_dict[new_key] = s_dict.pop(key)
+    return s_dict
+
+
+def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
+    transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
+    config = transformer_model.config
+
+    # first build dims
+    dims = {
+        'n_mels': config.num_mel_bins,
+        'n_vocab': config.vocab_size,
+        'n_audio_ctx': config.max_source_positions,
+        'n_audio_state': config.d_model,
+        'n_audio_head': config.encoder_attention_heads,
+        'n_audio_layer': config.encoder_layers,
+        'n_text_ctx': config.max_target_positions,
+        'n_text_state': config.d_model,
+        'n_text_head': config.decoder_attention_heads,
+        'n_text_layer': config.decoder_layers
+    }
+
+    state_dict = deepcopy(transformer_model.model.state_dict())
+    state_dict = rename_keys(state_dict)
+
+    torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
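
The converter also works standalone; its output is an ordinary Whisper checkpoint. A sketch (the model id comes from the config.json5 example above, the output path is arbitrary):

    import whisper

    from src.conversion.hf_converter import convert_hf_whisper

    # Download the HuggingFace checkpoint and save it in Whisper's .pt format.
    convert_hf_whisper("vumichien/whisper-large-v2-mix-jp", "whisper-large-v2-mix-jp.pt")

    # The result loads like any local Whisper model file.
    model = whisper.load_model("whisper-large-v2-mix-jp.pt")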
    	
        src/whisperContainer.py
    CHANGED
    
@@ -1,11 +1,14 @@
 # External programs
 import os
+from typing import List
 import whisper
+from src.config import ModelConfig
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None, 
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
         self.download_root = download_root
@@ -13,6 +16,9 @@ class WhisperContainer:
 
         # Will be created on demand
         self.model = None
+
+        # List of known models
+        self.models = models
 
     def get_model(self):
         if self.model is None:
@@ -32,21 +38,40 @@ class WhisperContainer:
         # Warning: Using private API here
         try:
             root_dir = self.download_root
+            model_config = self.get_model_config()
 
             if root_dir is None:
                 root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
 
             if self.model_name in whisper._MODELS:
                 whisper._download(whisper._MODELS[self.model_name], root_dir, False)
+            else:
+                # If the model is not in the official list, see if it needs to be downloaded
+                model_config.download_url(root_dir)
             return True
+
         except Exception as e:
             # Given that the API is private, it could change at any time. We don't want to crash the program
             print("Error pre-downloading model: " + str(e))
             return False
 
+    def get_model_config(self) -> ModelConfig:
+        """
+        Get the model configuration for the model.
+        """
+        for model in self.models:
+            if model.name == self.model_name:
+                return model
+        return None
+
     def _create_model(self):
         print("Loading whisper model " + self.model_name)
-        return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
+
+        model_config = self.get_model_config()
+        # Note that the model will not be downloaded in the case of an official Whisper model
+        model_path = model_config.download_url(self.download_root)
+
+        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
         """
@@ -71,12 +96,13 @@ class WhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
+        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
 
     def __setstate__(self, state):
         self.model_name = state["model_name"]
         self.device = state["device"]
         self.download_root = state["download_root"]
+        self.models = state["models"]
         self.model = None
         # Depickled objects must use the global cache
         self.cache = GLOBAL_MODEL_CACHE
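Taken together, a custom model is simply another entry in the models list: get_model_config() resolves the configured name, and _create_model() loads whatever path download_url() returns, while official model names still go through whisper._MODELS. A minimal usage sketch, assuming ModelConfig can be built from a name and a download URL (its real signature lives in src/config.py, added by this commit):

    import pickle

    from src.config import ModelConfig
    from src.whisperContainer import WhisperContainer

    # Hypothetical fine-tuned checkpoint; the name and URL are examples only.
    custom_models = [
        ModelConfig(name="whisper-large-nb", url="https://example.com/whisper-large-nb.pt"),
    ]

    container = WhisperContainer(model_name="whisper-large-nb", models=custom_models)
    model = container.get_model()  # lazily calls _create_model() on first use

    # __getstate__/__setstate__ keep the container picklable for the parallel
    # workers: the models list travels with it, the loaded model does not.
    clone = pickle.loads(pickle.dumps(container))
    assert clone.model is None
    assert clone.models[0].name == "whisper-large-nb"

One caveat in the new signature: models: List[ModelConfig] = [] is a mutable default argument, shared by every container constructed without an explicit list; passing the list explicitly, as above, avoids that Python pitfall.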