Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Refactor function names
Browse files
Also prepare code for creating a CLI
- app-local.py +2 -2
- app-network.py +2 -2
- app-shared.py +2 -2
- app.py +50 -46
- src/download.py +4 -4
    	
        app-local.py
    CHANGED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            # Run the app with no audio file restrictions
         | 
| 2 | 
            -
            from app import  | 
| 3 | 
            -
             | 
|  | |
| 1 | 
             
            # Run the app with no audio file restrictions
         | 
| 2 | 
            +
            from app import create_ui
         | 
| 3 | 
            +
            create_ui(-1)
         | 
    	
        app-network.py
    CHANGED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            # Run the app with no audio file restrictions, and make it available on the network
         | 
| 2 | 
            -
            from app import  | 
| 3 | 
            -
             | 
|  | |
| 1 | 
             
            # Run the app with no audio file restrictions, and make it available on the network
         | 
| 2 | 
            +
            from app import create_ui
         | 
| 3 | 
            +
            create_ui(-1, server_name="0.0.0.0")
         | 
    	
        app-shared.py
    CHANGED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            # Run the app with no audio file restrictions
         | 
| 2 | 
            -
            from app import  | 
| 3 | 
            -
             | 
|  | |
| 1 | 
             
            # Run the app with no audio file restrictions
         | 
| 2 | 
            +
            from app import create_ui
         | 
| 3 | 
            +
            create_ui(-1, share=True)
         | 
    	
        app.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ import ffmpeg | |
| 12 | 
             
            # UI
         | 
| 13 | 
             
            import gradio as gr
         | 
| 14 |  | 
| 15 | 
            -
            from src.download import ExceededMaximumDuration,  | 
| 16 | 
             
            from src.utils import slugify, write_srt, write_vtt
         | 
| 17 | 
             
            from src.vad import VadPeriodicTranscription, VadSileroTranscription
         | 
| 18 |  | 
| @@ -45,26 +45,27 @@ LANGUAGES = [ | |
| 45 | 
             
             "Hausa", "Bashkir", "Javanese", "Sundanese"
         | 
| 46 | 
             
            ]
         | 
| 47 |  | 
| 48 | 
            -
             | 
|  | |
|  | |
| 49 |  | 
| 50 | 
            -
            class UI:
         | 
| 51 | 
            -
                def __init__(self, inputAudioMaxDuration):
         | 
| 52 | 
             
                    self.vad_model = None
         | 
| 53 | 
             
                    self.inputAudioMaxDuration = inputAudioMaxDuration
         | 
|  | |
| 54 |  | 
| 55 | 
            -
                def  | 
| 56 | 
             
                    try:
         | 
| 57 | 
            -
                        source, sourceName = self. | 
| 58 |  | 
| 59 | 
             
                        try:
         | 
| 60 | 
             
                            selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         | 
| 61 | 
             
                            selectedModel = modelName if modelName is not None else "base"
         | 
| 62 |  | 
| 63 | 
            -
                            model = model_cache.get(selectedModel, None)
         | 
| 64 |  | 
| 65 | 
             
                            if not model:
         | 
| 66 | 
             
                                model = whisper.load_model(selectedModel)
         | 
| 67 | 
            -
                                model_cache[selectedModel] = model
         | 
| 68 |  | 
| 69 | 
             
                            # Callable for processing an audio file
         | 
| 70 | 
             
                            whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)
         | 
| @@ -100,36 +101,39 @@ class UI: | |
| 100 | 
             
                            text = result["text"]
         | 
| 101 |  | 
| 102 | 
             
                            language = result["language"]
         | 
| 103 | 
            -
                            languageMaxLineWidth =  | 
| 104 |  | 
| 105 | 
             
                            print("Max line width " + str(languageMaxLineWidth))
         | 
| 106 | 
            -
                            vtt =  | 
| 107 | 
            -
                            srt =  | 
| 108 |  | 
| 109 | 
             
                            # Files that can be downloaded
         | 
| 110 | 
             
                            downloadDirectory = tempfile.mkdtemp()
         | 
| 111 | 
             
                            filePrefix = slugify(sourceName, allow_unicode=True)
         | 
| 112 |  | 
| 113 | 
             
                            download = []
         | 
| 114 | 
            -
                            download.append( | 
| 115 | 
            -
                            download.append( | 
| 116 | 
            -
                            download.append( | 
| 117 |  | 
| 118 | 
             
                            return download, text, vtt
         | 
| 119 |  | 
| 120 | 
             
                        finally:
         | 
| 121 | 
             
                            # Cleanup source
         | 
| 122 | 
            -
                            if  | 
| 123 | 
             
                                print("Deleting source file " + source)
         | 
| 124 | 
             
                                os.remove(source)
         | 
| 125 |  | 
| 126 | 
             
                    except ExceededMaximumDuration as e:
         | 
| 127 | 
             
                        return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
         | 
| 128 |  | 
| 129 | 
            -
                def  | 
|  | |
|  | |
|  | |
| 130 | 
             
                    if urlData:
         | 
| 131 | 
             
                        # Download from YouTube
         | 
| 132 | 
            -
                        source =  | 
| 133 | 
             
                    else:
         | 
| 134 | 
             
                        # File input
         | 
| 135 | 
             
                        source = uploadFile if uploadFile is not None else microphoneData
         | 
| @@ -146,38 +150,38 @@ class UI: | |
| 146 |  | 
| 147 | 
             
                    return source, sourceName
         | 
| 148 |  | 
| 149 | 
            -
            def  | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
| 152 | 
            -
             | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
| 155 | 
            -
             | 
| 156 | 
            -
             | 
|  | |
|  | |
|  | |
| 157 |  | 
| 158 | 
            -
             | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
|  | |
|  | |
| 162 |  | 
| 163 | 
            -
             | 
|  | |
| 164 |  | 
| 165 | 
            -
            def  | 
| 166 | 
            -
             | 
|  | |
|  | |
| 167 |  | 
| 168 | 
            -
             | 
| 169 | 
            -
                    write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
         | 
| 170 | 
            -
                elif format == 'srt':
         | 
| 171 | 
            -
                    write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
         | 
| 172 | 
            -
                else:
         | 
| 173 | 
            -
                    raise Exception("Unknown format " + format)
         | 
| 174 |  | 
| 175 | 
            -
                segmentStream.seek(0)
         | 
| 176 | 
            -
                return segmentStream.read()
         | 
| 177 | 
            -
                
         | 
| 178 |  | 
| 179 | 
            -
            def  | 
| 180 | 
            -
                ui =  | 
| 181 |  | 
| 182 | 
             
                ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse " 
         | 
| 183 | 
             
                ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
         | 
| @@ -188,9 +192,9 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None): | |
| 188 | 
             
                if inputAudioMaxDuration > 0:
         | 
| 189 | 
             
                    ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
         | 
| 190 |  | 
| 191 | 
            -
                ui_article = "Read the [documentation  | 
| 192 |  | 
| 193 | 
            -
                demo = gr.Interface(fn=ui. | 
| 194 | 
             
                    gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
         | 
| 195 | 
             
                    gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         | 
| 196 | 
             
                    gr.Text(label="URL (YouTube, etc.)"),
         | 
| @@ -210,4 +214,4 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None): | |
| 210 | 
             
                demo.launch(share=share, server_name=server_name)   
         | 
| 211 |  | 
| 212 | 
             
            if __name__ == '__main__':
         | 
| 213 | 
            -
                 | 
|  | |
| 12 | 
             
            # UI
         | 
| 13 | 
             
            import gradio as gr
         | 
| 14 |  | 
| 15 | 
            +
            from src.download import ExceededMaximumDuration, download_url
         | 
| 16 | 
             
            from src.utils import slugify, write_srt, write_vtt
         | 
| 17 | 
             
            from src.vad import VadPeriodicTranscription, VadSileroTranscription
         | 
| 18 |  | 
|  | |
| 45 | 
             
             "Hausa", "Bashkir", "Javanese", "Sundanese"
         | 
| 46 | 
             
            ]
         | 
| 47 |  | 
| 48 | 
            +
            class WhisperTranscriber:
         | 
| 49 | 
            +
                def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
         | 
| 50 | 
            +
                    self.model_cache = dict()
         | 
| 51 |  | 
|  | |
|  | |
| 52 | 
             
                    self.vad_model = None
         | 
| 53 | 
             
                    self.inputAudioMaxDuration = inputAudioMaxDuration
         | 
| 54 | 
            +
                    self.deleteUploadedFiles = deleteUploadedFiles
         | 
| 55 |  | 
| 56 | 
            +
                def transcribe_file(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
         | 
| 57 | 
             
                    try:
         | 
| 58 | 
            +
                        source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
         | 
| 59 |  | 
| 60 | 
             
                        try:
         | 
| 61 | 
             
                            selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         | 
| 62 | 
             
                            selectedModel = modelName if modelName is not None else "base"
         | 
| 63 |  | 
| 64 | 
            +
                            model = self.model_cache.get(selectedModel, None)
         | 
| 65 |  | 
| 66 | 
             
                            if not model:
         | 
| 67 | 
             
                                model = whisper.load_model(selectedModel)
         | 
| 68 | 
            +
                                self.model_cache[selectedModel] = model
         | 
| 69 |  | 
| 70 | 
             
                            # Callable for processing an audio file
         | 
| 71 | 
             
                            whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)
         | 
|  | |
| 101 | 
             
                            text = result["text"]
         | 
| 102 |  | 
| 103 | 
             
                            language = result["language"]
         | 
| 104 | 
            +
                            languageMaxLineWidth = self.__get_max_line_width(language)
         | 
| 105 |  | 
| 106 | 
             
                            print("Max line width " + str(languageMaxLineWidth))
         | 
| 107 | 
            +
                            vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
         | 
| 108 | 
            +
                            srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
         | 
| 109 |  | 
| 110 | 
             
                            # Files that can be downloaded
         | 
| 111 | 
             
                            downloadDirectory = tempfile.mkdtemp()
         | 
| 112 | 
             
                            filePrefix = slugify(sourceName, allow_unicode=True)
         | 
| 113 |  | 
| 114 | 
             
                            download = []
         | 
| 115 | 
            +
                            download.append(self.__create_file(srt, downloadDirectory, filePrefix + "-subs.srt"));
         | 
| 116 | 
            +
                            download.append(self.__create_file(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
         | 
| 117 | 
            +
                            download.append(self.__create_file(text, downloadDirectory, filePrefix + "-transcript.txt"));
         | 
| 118 |  | 
| 119 | 
             
                            return download, text, vtt
         | 
| 120 |  | 
| 121 | 
             
                        finally:
         | 
| 122 | 
             
                            # Cleanup source
         | 
| 123 | 
            +
                            if self.deleteUploadedFiles:
         | 
| 124 | 
             
                                print("Deleting source file " + source)
         | 
| 125 | 
             
                                os.remove(source)
         | 
| 126 |  | 
| 127 | 
             
                    except ExceededMaximumDuration as e:
         | 
| 128 | 
             
                        return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
         | 
| 129 |  | 
| 130 | 
            +
                def clear_cache(self):
         | 
| 131 | 
            +
                    self.model_cache = dict()
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def __get_source(self, urlData, uploadFile, microphoneData):
         | 
| 134 | 
             
                    if urlData:
         | 
| 135 | 
             
                        # Download from YouTube
         | 
| 136 | 
            +
                        source = download_url(urlData, self.inputAudioMaxDuration)
         | 
| 137 | 
             
                    else:
         | 
| 138 | 
             
                        # File input
         | 
| 139 | 
             
                        source = uploadFile if uploadFile is not None else microphoneData
         | 
|  | |
| 150 |  | 
| 151 | 
             
                    return source, sourceName
         | 
| 152 |  | 
| 153 | 
            +
                def __get_max_line_width(self, language: str) -> int:
         | 
| 154 | 
            +
                    if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
         | 
| 155 | 
            +
                        # Chinese characters and kana are wider, so limit line length to 40 characters
         | 
| 156 | 
            +
                        return 40
         | 
| 157 | 
            +
                    else:
         | 
| 158 | 
            +
                        # TODO: Add more languages
         | 
| 159 | 
            +
                        # 80 latin characters should fit on a 1080p/720p screen
         | 
| 160 | 
            +
                        return 80
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
         | 
| 163 | 
            +
                    segmentStream = StringIO()
         | 
| 164 |  | 
| 165 | 
            +
                    if format == 'vtt':
         | 
| 166 | 
            +
                        write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
         | 
| 167 | 
            +
                    elif format == 'srt':
         | 
| 168 | 
            +
                        write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
         | 
| 169 | 
            +
                    else:
         | 
| 170 | 
            +
                        raise Exception("Unknown format " + format)
         | 
| 171 |  | 
| 172 | 
            +
                    segmentStream.seek(0)
         | 
| 173 | 
            +
                    return segmentStream.read()
         | 
| 174 |  | 
| 175 | 
            +
                def __create_file(self, text: str, directory: str, fileName: str) -> str:
         | 
| 176 | 
            +
                    # Write the text to a file
         | 
| 177 | 
            +
                    with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
         | 
| 178 | 
            +
                        file.write(text)
         | 
| 179 |  | 
| 180 | 
            +
                    return file.name
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 181 |  | 
|  | |
|  | |
|  | |
| 182 |  | 
| 183 | 
            +
            def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
         | 
| 184 | 
            +
                ui = WhisperTranscriber(inputAudioMaxDuration)
         | 
| 185 |  | 
| 186 | 
             
                ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse " 
         | 
| 187 | 
             
                ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
         | 
|  | |
| 192 | 
             
                if inputAudioMaxDuration > 0:
         | 
| 193 | 
             
                    ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
         | 
| 194 |  | 
| 195 | 
            +
                ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
         | 
| 196 |  | 
| 197 | 
            +
                demo = gr.Interface(fn=ui.transcribe_file, description=ui_description, article=ui_article, inputs=[
         | 
| 198 | 
             
                    gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
         | 
| 199 | 
             
                    gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         | 
| 200 | 
             
                    gr.Text(label="URL (YouTube, etc.)"),
         | 
|  | |
| 214 | 
             
                demo.launch(share=share, server_name=server_name)   
         | 
| 215 |  | 
| 216 | 
             
            if __name__ == '__main__':
         | 
| 217 | 
            +
                create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)
         | 
    	
        src/download.py
    CHANGED
    
    | @@ -13,16 +13,16 @@ class FilenameCollectorPP(PostProcessor): | |
| 13 | 
             
                    self.filenames.append(information["filepath"])
         | 
| 14 | 
             
                    return [], information
         | 
| 15 |  | 
| 16 | 
            -
            def  | 
| 17 | 
             
                try:
         | 
| 18 | 
            -
                    return  | 
| 19 | 
             
                except yt_dlp.utils.DownloadError as e:
         | 
| 20 | 
             
                    # In case of an OS error, try again with a different output template
         | 
| 21 | 
             
                    if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
         | 
| 22 | 
            -
                        return  | 
| 23 | 
             
                    pass
         | 
| 24 |  | 
| 25 | 
            -
            def  | 
| 26 | 
             
                destinationDirectory = mkdtemp()
         | 
| 27 |  | 
| 28 | 
             
                ydl_opts = {
         | 
|  | |
| 13 | 
             
                    self.filenames.append(information["filepath"])
         | 
| 14 | 
             
                    return [], information
         | 
| 15 |  | 
| 16 | 
            +
            def download_url(url: str, maxDuration: int = None):
         | 
| 17 | 
             
                try:
         | 
| 18 | 
            +
                    return _perform_download(url, maxDuration=maxDuration)
         | 
| 19 | 
             
                except yt_dlp.utils.DownloadError as e:
         | 
| 20 | 
             
                    # In case of an OS error, try again with a different output template
         | 
| 21 | 
             
                    if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
         | 
| 22 | 
            +
                        return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
         | 
| 23 | 
             
                    pass
         | 
| 24 |  | 
| 25 | 
            +
            def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None):
         | 
| 26 | 
             
                destinationDirectory = mkdtemp()
         | 
| 27 |  | 
| 28 | 
             
                ydl_opts = {
         |