import os
import torch
import whisper
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List, Optional
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple
import gc
from copy import deepcopy

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
                                 DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR)
from modules.utils.subtitle_manager import (get_srt, get_vtt, get_txt, get_plaintext, get_csv,
                                            write_file, safe_filename)
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
from modules.translation.nllb_inference import NLLBInference
from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS


class WhisperBase(ABC):
    """
    Shared transcription pipeline for the concrete Whisper backends.

    Subclasses implement `transcribe` and `update_model`. This base class adds optional
    pre-processing (background-music separation, VAD), post-processing (VAD timestamp
    restoration, speaker diarization) and the Gradio entry points that write subtitle files.
    """

    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)
        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()
        self.music_separator = MusicSeparator(
            model_dir=uvr_model_dir,
            output_dir=os.path.join(output_dir, "UVR")
        )

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        # Every available model is listed as translatable. An earlier version restricted
        # this list to "large", "large-v1", "large-v2" and "large-v3".
        self.translatable_models = whisper.available_models()
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ):
        """Inference whisper model to transcribe"""
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Initialize whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        If enabled, background music is separated out and VAD removes silent parts of the
        audio in pre-processing. If enabled, diarization is performed in post-processing.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be file path or binary type.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
            Not used by this method itself; kept for signature compatibility.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time reported by `transcribe()`, i.e. the transcription step only
            (BGM separation, VAD and diarization time are not included)
        """
        start_time = datetime.now()
        params = WhisperParameters.as_value(*whisper_params)

        # Offload flags are read from the persisted defaults file, not from the UI values.
        # Named *_config so the ``*whisper_params`` varargs above are not shadowed.
        default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        whisper_config = default_params["whisper"]
        diarization_config = default_params["diarization"]
        bool_whisper_enable_offload = whisper_config["enable_offload"]
        bool_diarization_enable_offload = diarization_config["enable_offload"]

        # The UI supplies a language *name*. Whisper expects a language *code*, or None
        # to auto-detect.
        if params.lang == "Automatic Detection":
            params.lang = None
        elif params.lang is not None:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        if params.is_bgm_separate:
            music, audio, _ = self.music_separator.separate(
                audio=audio,
                model_name=params.uvr_model_size,
                device=params.uvr_device,
                segment_size=params.uvr_segment_size,
                save_file=params.uvr_save_file,
                progress=progress
            )

            # Multi-channel separator output is down-mixed (mean over axis 1) and then
            # resampled from the separator's reported sample rate (16000 when unknown).
            if audio.ndim >= 2:
                audio = audio.mean(axis=1)
                if self.music_separator.audio_info is None:
                    origin_sample_rate = 16000
                else:
                    origin_sample_rate = self.music_separator.audio_info.sample_rate
                audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

            if params.uvr_enable_offload:
                self.music_separator.offload()

        # Keep the un-trimmed audio for diarization; VAD below may replace `audio`.
        origin_audio = deepcopy(audio)

        if params.vad_filter:
            # gr.Number() cannot hold float('inf'), so an empty value or one >= 9999 means "no limit".
            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            progress(0, desc="Filtering silent parts from audio...")
            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            vad_processed, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

            if vad_processed.size > 0:
                audio = vad_processed
            else:
                # Nothing survived VAD: transcribe the original audio and skip timestamp restoration.
                params.vad_filter = False

        result, elapsed_time = self.transcribe(
            audio,
            progress,
            *astuple(params)
        )
        if bool_whisper_enable_offload:
            self.offload()

        if params.vad_filter:
            # Map timestamps of the silence-trimmed audio back onto the original timeline.
            restored_result = self.vad.restore_speech_timestamps(
                segments=result,
                speech_chunks=speech_chunks,
            )
            if restored_result:
                result = restored_result
            else:
                print("VAD detected no speech segments in the audio.")

        if params.is_diarize:
            progress(0.99, desc="Diarizing speakers...")
            result, _ = self.diarizer.run(
                audio=origin_audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
                device=params.diarization_device
            )
            if bool_diarization_enable_offload:
                self.diarizer.offload()

        if not result:
            print("Whisper did not detected any speech segments in the audio.")
            result = list()

        progress(1.0, desc="Processing done!")
        return result, elapsed_time

    @staticmethod
    def _detect_file_language(model, file: str) -> Tuple[str, str]:
        """
        Detect the spoken language of `file` with a loaded openai-whisper `model`.

        The audio is passed through `whisper.pad_or_trim` before the mel spectrogram
        is computed.

        Returns
        ----------
        language: str
            Capitalized language name (e.g. "English"), or "" if the most probable code is
            not in `whisper.tokenizer.LANGUAGES`.
        probability: str
            Probability of that language as a rounded whole-number percentage, or "".
        """
        mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
        _, probs = model.detect_language(mel)
        lang_code = max(probs, key=probs.get)
        language_name = whisper.tokenizer.LANGUAGES.get(lang_code)
        if language_name is None:
            return "", ""
        return language_name.capitalize(), str(round(probs[lang_code] * 100))

    def transcribe_file(self,
                        files: Optional[List] = None,
                        input_folder_path: Optional[str] = None,
                        file_format: str = "SRT",
                        add_timestamp: bool = True,
                        translate_output: bool = False,
                        translate_model: str = "",
                        target_lang: str = "",
                        add_timestamp_preview: bool = False,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle files from Files.

        Every file is written as txt, srt and csv regardless of `file_format`.

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle File format from gr.Dropdown(). Currently unused: txt, srt and csv are always written.
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        translate_output: bool
            Translate output with NLLB
        translate_model: str
            NLLB translation model to use
        target_lang: str
            Target language to use
        add_timestamp_preview: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp to output preview
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Concatenated preview of all transcriptions to return to gr.Textbox()
        result_file_path:
            Output file paths to return to gr.Files()
        total_info:
            Per-file language / translation summary plus total processing time
        Returns None if an exception occurred; the error is printed.
        """
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
            if isinstance(files, str):
                files = [files]
            if files and isinstance(files[0], gr.utils.NamedString):
                files = [file.name for file in files]
            files = files or []

            files_info = {}
            files_to_download = {}
            time_start = datetime.now()

            params = WhisperParameters.as_value(*whisper_params)

            # Detect every file's language up front with a small "base" model, then free it
            # so it does not occupy memory while the main model transcribes.
            detected_langs = []
            if files:
                lang_detect_model = whisper.load_model("base")
                detected_langs = [self._detect_file_language(lang_detect_model, file) for file in files]
                del lang_detect_model
                self.release_cuda_memory()

            nllb_ready = False
            for file, (file_language, file_lang_probs) in zip(files, detected_langs):
                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )

                source_lang = file_language

                # Whisper's built-in translate task always targets English.
                transcription_note = ""
                if params.is_translate:
                    if source_lang != "English":
                        transcription_note = "To English"
                        source_lang = "English"
                    else:
                        transcription_note = "Already in English"

                # Optional NLLB translation of the transcribed segments.
                translation_note = ""
                if translate_output:
                    if source_lang != target_lang:
                        if source_lang in NLLB_AVAILABLE_LANGS.keys():
                            if not nllb_ready:
                                # One translator instance is shared by all files of this call.
                                self.nllb_inf = NLLBInference()
                                nllb_ready = True
                            transcribed_segments = self.nllb_inf.translate_text(
                                input_list_dict=transcribed_segments,
                                model_size=translate_model,
                                src_lang=source_lang,
                                tgt_lang=target_lang,
                                speaker_diarization=params.is_diarize
                            )
                            translation_note = "To " + target_lang
                        else:
                            translation_note = source_lang + " not supported"
                    else:
                        translation_note = "Already in " + target_lang

                file_name, file_ext = os.path.splitext(os.path.basename(file))

                # Preview text, with or without timestamps.
                if add_timestamp_preview:
                    subtitle = get_txt(transcribed_segments)
                else:
                    subtitle = get_plaintext(transcribed_segments)

                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task,
                                         "lang": file_language, "lang_prob": file_lang_probs,
                                         "input_source_file": (file_name + file_ext),
                                         "translation": translation_note,
                                         "transcription": transcription_note}

                for output_format in ("txt", "srt", "csv"):
                    _, file_path = self.generate_and_write_file(
                        file_name=file_name,
                        transcribed_segments=transcribed_segments,
                        add_timestamp=add_timestamp,
                        file_format=output_format,
                        output_dir=self.output_dir
                    )
                    files_to_download[f"{file_name}_{output_format}"] = {"path": file_path}

            total_result = ""
            total_info = ""
            for info in files_info.values():
                total_result += f'{info["subtitle"]}'
                total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'

                if params.is_translate:
                    total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'

                if translate_output:
                    total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'

            time_end = datetime.now()
            total_info += f"\nTotal processing time: {self.format_time((time_end - time_start).total_seconds())}"

            result_str = total_result.rstrip("\n")
            result_file_path = [info['path'] for info in files_to_download.values()]

            return [result_str, result_file_path, total_info]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt, csv]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        Returns None if an exception occurred; the error is printed.
        """
        try:
            progress(0, desc="Loading Audio...")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str = "SRT",
                           add_timestamp: bool = True,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt, csv]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        Returns None if an exception occurred; the error is printed.
        """
        audio = None
        try:
            progress(0, desc="Loading Audio from Youtube...")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)
            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                add_timestamp,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            # Delete the downloaded audio whether or not transcription succeeded.
            if isinstance(audio, (str, os.PathLike)) and os.path.exists(audio):
                try:
                    os.remove(audio)
                except OSError as e:
                    print(f"Error transcribing file: {e}")
            self.release_cuda_memory()

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp ("YYYYMMDD HHMMSS") to the end of the filename.
        file_format: str
            File format to write (case-insensitive, surrounding whitespace ignored).
            Supported formats: [SRT, WebVTT, txt, csv]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            output file path

        Raises
        ----------
        ValueError
            If `file_format` is not one of the supported formats.
        """
        if add_timestamp:
            timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
            output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        # format -> (segments-to-text function, file extension)
        writers = {
            "srt": (get_srt, ".srt"),
            "webvtt": (get_vtt, ".vtt"),
            "txt": (get_txt, ".txt"),
            "csv": (get_csv, ".csv"),
        }
        normalized_format = file_format.strip().lower()
        if normalized_format not in writers:
            raise ValueError(f"Unsupported file format: {file_format!r}. "
                             f"Supported formats: SRT, WebVTT, txt, csv")

        to_text, extension = writers[normalized_format]
        content = to_text(transcribed_segments)
        output_path += extension

        write_file(content, output_path)
        return content, output_path

    def offload(self):
        """Offload the model and free up the memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            self.release_cuda_memory()
        gc.collect()

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        The total is rounded to whole seconds *before* it is split into components, so the
        output never contains "60 seconds" or "60 minutes".

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription, in seconds

        Returns
        ----------
        Time format string, e.g. "1 hour 2 minutes 3 seconds". Hours and minutes are omitted
        when zero; seconds are always shown.
        """
        total_seconds = round(elapsed_time)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        parts = []
        if hours:
            parts.append(f"{hours} hour" if hours == 1 else f"{hours} hours")
        if minutes:
            parts.append(f"{minutes} minute" if minutes == 1 else f"{minutes} minutes")
        parts.append(f"{seconds} second" if seconds == 1 else f"{seconds} seconds")
        return " ".join(parts)

    @staticmethod
    def get_device():
        """Return "cuda", "mps" or "cpu" depending on what this machine supports."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            if not WhisperBase.is_sparse_api_supported():
                # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
                return "cpu"
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def is_sparse_api_supported():
        """Return True if a sparse COO tensor can be created on the MPS device."""
        if not torch.backends.mps.is_available():
            return False

        try:
            device = torch.device("mps")
            torch.sparse_coo_tensor(
                indices=torch.tensor([[0, 1], [2, 3]]),
                values=torch.tensor([1, 2]),
                size=(4, 4),
                device=device
            )
            return True
        except RuntimeError:
            return False

    @staticmethod
    def release_cuda_memory():
        """Release cached CUDA memory and reset peak-memory statistics (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # reset_max_memory_allocated() is deprecated and now just delegates to this call.
            torch.cuda.reset_peak_memory_stats()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Remove gradio cached files"""
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(
        params: WhisperValues,
        file_format: str = "SRT",
        add_timestamp: bool = True
    ):
        """Cache parameters to the yaml file"""
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        param_to_cache = params.to_dict()

        # Shallow merge: each top-level section in `param_to_cache` replaces the cached one.
        cached_yaml = {**cached_params, **param_to_cache}
        cached_yaml["whisper"]["add_timestamp"] = add_timestamp
        cached_yaml["whisper"]["file_format"] = file_format

        suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
        if suppress_token and isinstance(suppress_token, list):
            cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)

        # Store the language as the UI-facing name rather than Whisper's language code.
        if cached_yaml["whisper"].get("lang", None) is None:
            cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
        else:
            language_dict = whisper.tokenizer.LANGUAGES
            cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]

        # gr.Number() cannot display infinity, so persist the UI's sentinel value instead.
        if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
            cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX

        if cached_yaml:
            save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

    @staticmethod
    def resample_audio(audio: Union[str, np.ndarray],
                       new_sample_rate: int = 16000,
                       original_sample_rate: Optional[int] = None,) -> np.ndarray:
        """Resamples audio to 16k sample rate, standard on Whisper model"""
        if isinstance(audio, str):
            audio, original_sample_rate = torchaudio.load(audio)
        else:
            if original_sample_rate is None:
                raise ValueError("original_sample_rate must be provided when audio is numpy array.")
            audio = torch.from_numpy(audio)
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
        resampled_audio = resampler(audio).numpy()
        return resampled_audio