from whisperx.alignment import (
    DEFAULT_ALIGN_MODELS_TORCH as DAMT,
    DEFAULT_ALIGN_MODELS_HF as DAMHF,
)
from whisperx.utils import TO_LANGUAGE_CODE
import whisperx
import torch
import gc
import os
import soundfile as sf
from IPython.utils import capture  # noqa
from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
from .logging_setup import logger
from .postprocessor import sanitize_file_name
from .utils import remove_directory_contents, run_command

# ZERO GPU CONFIG
import spaces  # noqa
import copy  # noqa
import random
import time


def random_sleep():
    # Sleep for a random 7.2-9.9 s when running on ZeroGPU.
    if os.environ.get("ZERO_GPU") == "TRUE":
        print("Random sleep")
        sleep_time = round(random.uniform(7.2, 9.9), 1)
        time.sleep(sleep_time)


def load_and_transcribe_audio(
    asr_model,
    audio,
    compute_type,
    language,
    asr_options,
    batch_size,
    segment_duration_limit,
):
    # Load model
    model = whisperx.load_model(
        asr_model,
        os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        compute_type=compute_type,
        language=language,
        asr_options=asr_options,
    )

    # Transcribe audio
    result = model.transcribe(
        audio,
        batch_size=batch_size,
        chunk_size=segment_duration_limit,
        print_progress=True,
    )

    del model
    gc.collect()
    torch.cuda.empty_cache()  # noqa

    return result


def load_align_and_align_segments(result, audio, DAMHF):
    # Load alignment model
    model_a, metadata = whisperx.load_align_model(
        language_code=result["language"],
        device=os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        model_name=None
        if result["language"] in DAMHF.keys()
        else EXTRA_ALIGN[result["language"]],
    )

    # Align segments
    alignment_result = whisperx.align(
        result["segments"],
        model_a,
        metadata,
        audio,
        os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        return_char_alignments=True,
        print_progress=False,
    )

    # Clean up
    del model_a
    gc.collect()
    torch.cuda.empty_cache()  # noqa

    return alignment_result


def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
    if os.environ.get("ZERO_GPU") == "TRUE":
        diarize_model.model.to(torch.device("cuda"))
    diarize_segments = diarize_model(
        audio_wav,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )
    return diarize_segments
# ZERO GPU CONFIG

ASR_MODEL_OPTIONS = [
    "tiny",
    "base",
    "small",
    "medium",
    "large",
    "large-v1",
    "large-v2",
    "large-v3",
    "distil-large-v2",
    "Systran/faster-distil-whisper-large-v3",
    "tiny.en",
    "base.en",
    "small.en",
    "medium.en",
    "distil-small.en",
    "distil-medium.en",
    "OpenAI_API_Whisper",
]

COMPUTE_TYPE_GPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int8_float16",
    "int8_bfloat16",
    "float16",
    "bfloat16",
    "float32",
]

COMPUTE_TYPE_CPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int16",
    "float32",
]

WHISPER_MODELS_PATH = './WHISPER_MODELS'


def openai_api_whisper(
    input_audio_file,
    source_lang=None,
    chunk_duration=1800
):
    info = sf.info(input_audio_file)
    duration = info.duration

    output_directory = "./whisper_api_audio_parts"
    os.makedirs(output_directory, exist_ok=True)
    remove_directory_contents(output_directory)

    if duration > chunk_duration:
        # Split the audio into chunks of chunk_duration seconds
        # (30 minutes by default)
        cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
        run_command(cm)

        # Get the list of generated chunk files
        chunk_files = sorted(
            [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
        )
    else:
        one_file = f"{output_directory}/output000.ogg"
        cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
        run_command(cm)
        chunk_files = [one_file]

    # Transcribe each chunk with the OpenAI Whisper API
    segments = []
    language = source_lang if source_lang else None

    from openai import OpenAI
    client = OpenAI()

    for i, chunk in enumerate(chunk_files):
        with open(chunk, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language=language,
                response_format="verbose_json",
                timestamp_granularities=["segment"],
            )

        try:
            transcript_dict = transcription.model_dump()
        except:  # noqa
            transcript_dict = transcription.to_dict()

        if language is None:
            logger.info(f'Language detected: {transcript_dict["language"]}')
            language = TO_LANGUAGE_CODE[transcript_dict["language"]]

        # Offset segment times by the start time of the current chunk
        chunk_time = chunk_duration * i

        for seg in transcript_dict["segments"]:
            if "start" in seg.keys():
                segments.append(
                    {
                        "text": seg["text"],
                        "start": seg["start"] + chunk_time,
                        "end": seg["end"] + chunk_time,
                    }
                )

    audio = whisperx.load_audio(input_audio_file)
    result = {"segments": segments, "language": language}

    return audio, result


def find_whisper_models():
    path = WHISPER_MODELS_PATH
    folders = []

    if os.path.exists(path):
        for folder in os.listdir(path):
            folder_path = os.path.join(path, folder)
            if (
                os.path.isdir(folder_path)
                and 'model.bin' in os.listdir(folder_path)
            ):
                folders.append(folder)
    return folders


def transcribe_speech(
    audio_wav,
    asr_model,
    compute_type,
    batch_size,
    SOURCE_LANGUAGE,
    literalize_numbers=True,
    segment_duration_limit=15,
):
    """
    Transcribe speech using a whisper model.

    Parameters:
    - audio_wav (str): Path to the audio file in WAV format.
    - asr_model (str): The whisper model to be loaded.
    - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
    - batch_size (int): Batch size for transcription.
    - SOURCE_LANGUAGE (str): Source language for transcription.
    - literalize_numbers (bool): Whether to transcribe numbers as words by
        suppressing numerals (default: True).
    - segment_duration_limit (int): Maximum duration, in seconds, of each
        transcription chunk (default: 15).

    Returns:
    - Tuple containing:
        - audio: Loaded audio file.
        - result: Transcription result as a dictionary.
    """
    if asr_model == "OpenAI_API_Whisper":
        if literalize_numbers:
            logger.info(
                "OpenAI's API Whisper does not support "
                "the literalization of numbers."
            )
        return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)

    # https://github.com/openai/whisper/discussions/277
    # Simplified-Chinese initial prompt ("The following are Mandarin
    # sentences.") to bias the output toward Simplified Chinese.
    prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
    SOURCE_LANGUAGE = (
        SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
    )
    asr_options = {
        "initial_prompt": prompt,
        "suppress_numerals": literalize_numbers
    }

    if asr_model not in ASR_MODEL_OPTIONS:
        base_dir = WHISPER_MODELS_PATH
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)
        model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))

        if not os.path.exists(model_dir):
            from ctranslate2.converters import TransformersConverter

            quantization = "float32"

            # Download and convert the new model to CTranslate2 format
            try:
                converter = TransformersConverter(
                    asr_model,
                    low_cpu_mem_usage=True,
                    copy_files=[
                        "tokenizer_config.json", "preprocessor_config.json"
                    ]
                )
                converter.convert(
                    model_dir,
                    quantization=quantization,
                    force=False
                )
            except Exception as error:
                if "File tokenizer_config.json does not exist" in str(error):
                    converter._copy_files = [
                        "tokenizer.json", "preprocessor_config.json"
                    ]
                    converter.convert(
                        model_dir,
                        quantization=quantization,
                        force=True
                    )
                else:
                    raise error

        asr_model = model_dir
        logger.info(f"ASR Model: {str(model_dir)}")

    audio = whisperx.load_audio(audio_wav)

    result = load_and_transcribe_audio(
        asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options,
        batch_size, segment_duration_limit
    )

    # Chinese detected without the Simplified-Chinese prompt:
    # report it as Traditional (zh-TW).
    if result["language"] == "zh" and not prompt:
        result["language"] = "zh-TW"
        logger.info("Chinese - Traditional (zh-TW)")

    return audio, result
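

# Usage sketch (illustrative, not executed): any model identifier outside
# ASR_MODEL_OPTIONS is treated as a Hugging Face repo id, converted to
# CTranslate2 format on first use, and cached under WHISPER_MODELS_PATH.
# The repo id and file path below are assumptions, not project defaults.
#
#   audio, result = transcribe_speech(
#       "video_audio.wav",
#       asr_model="openai/whisper-medium",  # converted and cached on first use
#       compute_type="float32",
#       batch_size=8,
#       SOURCE_LANGUAGE="en",
#   )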


def align_speech(audio, result):
    """
    Aligns speech segments based on the provided audio and result metadata.

    Parameters:
    - audio (array): The audio data in a suitable format for alignment.
    - result (dict): Metadata containing information about the segments
        and language.

    Returns:
    - result (dict): Updated metadata after aligning the segments with
        the audio. This includes character-level alignments if
        'return_char_alignments' is set to True.

    Notes:
    - This function uses language-specific models to align speech segments.
    - It performs language compatibility checks and selects the
        appropriate alignment model.
    - Cleans up memory by releasing resources after alignment.
    """
    DAMHF.update(DAMT)  # lang align
    if (
        result["language"] not in DAMHF.keys()
        and result["language"] not in EXTRA_ALIGN.keys()
    ):
        logger.warning(
            "Automatic detection: source language not compatible with "
            "alignment"
        )
        raise ValueError(
            f"The detected language '{result['language']}' is not compatible "
            "with alignment; select the source language manually to avoid "
            "this error."
        )
    if (
        result["language"] in EXTRA_ALIGN.keys()
        and EXTRA_ALIGN[result["language"]] == ""
    ):
        lang_name = (
            INVERTED_LANGUAGES[result["language"]]
            if result["language"] in INVERTED_LANGUAGES.keys()
            else result["language"]
        )
        logger.warning(
            "No compatible wav2vec2 model found "
            f"for the language '{lang_name}', skipping alignment."
        )
        return result

    # random_sleep()
    result = load_align_and_align_segments(result, audio, DAMHF)

    return result
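

# Usage sketch (illustrative): alignment follows transcription and reuses the
# audio array returned by transcribe_speech. Languages without a compatible
# wav2vec2 model either come back unaligned or raise ValueError, so callers
# may want to guard the call as below.
#
#   try:
#       result = align_speech(audio, result)
#   except ValueError as err:
#       logger.warning(f"Skipping alignment: {err}")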


diarization_models = {
    "pyannote_3.1": "pyannote/speaker-diarization-3.1",
    "pyannote_2.1": "pyannote/[email protected]",
    "disable": "",
}


def reencode_speakers(result):
    if result["segments"][0]["speaker"] == "SPEAKER_00":
        return result

    speaker_mapping = {}
    counter = 0

    logger.debug("Reencode speakers")

    for segment in result["segments"]:
        old_speaker = segment["speaker"]
        if old_speaker not in speaker_mapping:
            speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
            counter += 1
        segment["speaker"] = speaker_mapping[old_speaker]

    return result


def diarize_speech(
    audio_wav,
    result,
    min_speakers,
    max_speakers,
    YOUR_HF_TOKEN,
    model_name="pyannote/[email protected]",
):
    """
    Performs speaker diarization on speech segments.

    Parameters:
    - audio_wav (array): Audio data in WAV format to perform speaker
        diarization.
    - result (dict): Metadata containing information about speech segments
        and alignments.
    - min_speakers (int): Minimum number of speakers expected in the audio.
    - max_speakers (int): Maximum number of speakers expected in the audio.
    - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
        authentication.
    - model_name (str): Name of the speaker diarization model to be used
        (default: "pyannote/[email protected]").

    Returns:
    - result_diarize (dict): Updated metadata after assigning speaker
        labels to segments.

    Notes:
    - This function utilizes a speaker diarization model to label speaker
        segments in the audio.
    - It assigns speakers to word-level segments based on diarization results.
    - Cleans up memory by releasing resources after diarization.
    - If only one speaker is specified, each segment is automatically assigned
        as the first speaker, eliminating the need for diarization inference.
    """
    if max(min_speakers, max_speakers) > 1 and model_name:
        try:
            diarize_model = whisperx.DiarizationPipeline(
                model_name=model_name,
                use_auth_token=YOUR_HF_TOKEN,
                device=os.environ.get("SONITR_DEVICE"),
            )
        except Exception as error:
            error_str = str(error)
            gc.collect()
            torch.cuda.empty_cache()  # noqa
            if "'NoneType' object has no attribute 'to'" in error_str:
                if model_name == diarization_models["pyannote_2.1"]:
                    raise ValueError(
                        "Accept the license agreement to use Pyannote 2.1: "
                        "you need a Hugging Face account and must accept "
                        "the license for the models "
                        "https://huggingface.co/pyannote/speaker-diarization "
                        "and https://huggingface.co/pyannote/segmentation. "
                        "Get your access token here: "
                        "https://hf.co/settings/tokens"
                    )
                elif model_name == diarization_models["pyannote_3.1"]:
                    raise ValueError(
                        "New license for Pyannote 3.1: you need a Hugging "
                        "Face account and must accept the license for the "
                        "models "
                        "https://huggingface.co/pyannote/speaker-diarization-3.1 "  # noqa
                        "and https://huggingface.co/pyannote/segmentation-3.0"
                    )
            else:
                raise error

        random_sleep()
        diarize_segments = diarize_audio(
            diarize_model, audio_wav, min_speakers, max_speakers
        )

        result_diarize = whisperx.assign_word_speakers(
            diarize_segments, result
        )

        for segment in result_diarize["segments"]:
            if "speaker" not in segment:
                segment["speaker"] = "SPEAKER_00"
                logger.warning(
                    f"No speaker detected at {segment['start']}. The first "
                    "TTS voice will be used for the segment text: "
                    f"{segment['text']}"
                )

        del diarize_model
        gc.collect()
        torch.cuda.empty_cache()  # noqa
    else:
        result_diarize = result
        result_diarize["segments"] = [
            {**item, "speaker": "SPEAKER_00"}
            for item in result_diarize["segments"]
        ]
    return reencode_speakers(result_diarize)
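

if __name__ == "__main__":
    # Minimal end-to-end sketch: transcribe -> align -> diarize.
    # "example.wav", the speaker counts, and the token lookup are placeholders
    # for illustration; SONITR_DEVICE is normally set by the main app, so a
    # CPU fallback is used here.
    os.environ.setdefault("SONITR_DEVICE", "cpu")

    wav_path = "example.wav"  # placeholder input file
    audio, result = transcribe_speech(
        wav_path,
        asr_model="base",
        compute_type="int8",
        batch_size=4,
        SOURCE_LANGUAGE="en",
    )
    result = align_speech(audio, result)
    result = diarize_speech(
        wav_path,
        result,
        min_speakers=1,
        max_speakers=2,
        YOUR_HF_TOKEN=os.environ.get("YOUR_HF_TOKEN", ""),
        model_name=diarization_models["pyannote_3.1"],
    )
    for seg in result["segments"][:5]:
        logger.info(f"{seg['speaker']}: {seg['text']}")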