import os
from dataclasses import replace
from math import ceil
from typing import List, Optional, Union

import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator

from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import SingleSegment, TranscriptionResult
from whisperx.vads import Pyannote, Silero, Vad
from whisperx.vads.pyannote import Binarize


def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = []
    for i in range(tokenizer.eot):
        token = tokenizer.decode([i]).removeprefix(" ")
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(i)
    return numeral_symbol_tokens

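
# A minimal usage sketch (illustrative only; `model` and the chosen language
# are assumptions, not part of this module). In practice this helper is wired
# up by `load_model` below when `asr_options={"suppress_numerals": True}`:
#
#   tokenizer = Tokenizer(
#       model.hf_tokenizer, model.model.is_multilingual,
#       task="transcribe", language="en",
#   )
#   bad_ids = find_numeral_symbol_tokens(tokenizer)
#   options = replace(
#       options, suppress_tokens=list(set(bad_ids + options.suppress_tokens))
#   )
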
class WhisperModel(faster_whisper.WhisperModel):
    """
    WhisperModel subclass providing batched inference for faster-whisper.
    Currently only works in non-timestamp mode, with a fixed prompt shared by
    all samples in a batch.
    """

    def generate_segment_batched(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
    ):
        batch_size = features.shape[0]
        all_tokens = []
        prompt_reset_since = 0
        if options.initial_prompt is not None:
            initial_prompt = " " + options.initial_prompt.strip()
            initial_prompt_tokens = tokenizer.encode(initial_prompt)
            all_tokens.extend(initial_prompt_tokens)
        previous_tokens = all_tokens[prompt_reset_since:]
        prompt = self.get_prompt(
            tokenizer,
            previous_tokens,
            without_timestamps=options.without_timestamps,
            prefix=options.prefix,
            hotwords=options.hotwords,
        )

        encoder_output = self.encode(features)

        result = self.model.generate(
            encoder_output,
            [prompt] * batch_size,
            beam_size=options.beam_size,
            patience=options.patience,
            length_penalty=options.length_penalty,
            max_length=self.max_length,
            suppress_blank=options.suppress_blank,
            suppress_tokens=options.suppress_tokens,
        )
        tokens_batch = [x.sequences_ids[0] for x in result]

        def decode_batch(tokens: List[List[int]]) -> List[str]:
            res = []
            for tk in tokens:
                # drop special tokens (ids >= eot) before decoding
                res.append([token for token in tk if token < tokenizer.eot])
            return tokenizer.tokenizer.decode_batch(res)

        text = decode_batch(tokens_batch)
        return encoder_output, text, tokens_batch

    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
        # unsqueeze if batch size = 1
        if len(features.shape) == 2:
            features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)
        return self.model.encode(features, to_cpu=to_cpu)


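# Shape note (an assumption based on `log_mel_spectrogram` usage in this file):
# the `features` passed to `generate_segment_batched` form a
# (batch, n_mels, n_frames) array, one padded 30s window per row, so a single
# call yields one transcript per row:
#
#   encoder_output, texts, token_ids = model.generate_segment_batched(
#       features, tokenizer, options
#   )
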
class FasterWhisperPipeline(Pipeline):
    """
    Hugging Face Pipeline wrapper for WhisperModel.
    """

    # TODO:
    # - add support for timestamp mode
    # - add support for custom inference kwargs

    def __init__(
        self,
        model: WhisperModel,
        vad,
        vad_params: dict,
        options: TranscriptionOptions,
        tokenizer: Optional[Tokenizer] = None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.options = options
        self.preset_language = language
        self.suppress_numerals = suppress_numerals
        self._batch_size = kwargs.pop("batch_size", None)
        self._num_workers = 1
        self._preprocess_params, self._forward_params, self._postprocess_params = (
            self._sanitize_parameters(**kwargs)
        )
        self.call_count = 0
        self.framework = framework
        if self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
            elif device < 0:
                self.device = torch.device("cpu")
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device

        # deliberately bypass Pipeline.__init__; the state it would set up
        # has already been assigned manually above
        super(Pipeline, self).__init__()
        self.vad_model = vad
        self._vad_params = vad_params
        self.last_speech_timestamp = 0.0

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, input_dict):
        audio = input_dict["inputs"]
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        features = log_mel_spectrogram(
            audio,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=N_SAMPLES - audio.shape[0],
        )
        return {
            "inputs": features,
            "start": input_dict["start"],
            "end": input_dict["end"],
            "segment_size": input_dict["segment_size"],
        }

    def _forward(self, model_inputs):
        encoder_output, _text, tokens = self.model.generate_segment_batched(
            model_inputs["inputs"], self.tokenizer, self.options
        )
        outputs = [
            [
                {
                    "tokens": tokens[i],
                    "start": model_inputs["start"][i],
                    "end": model_inputs["end"][i],
                    # seek is expressed in mel frames (100 per second)
                    "seek": int(model_inputs["start"][i] * 100),
                }
            ]
            for i in range(len(tokens))
        ]
        self.last_speech_timestamp = self.model.add_word_timestamps(
            outputs,
            self.tokenizer,
            encoder_output,
            num_frames=model_inputs["segment_size"],
            prepend_punctuations="\"'“¿([{-",
            append_punctuations="\"'.。,,!!??::”)]}、",
            last_speech_timestamp=self.last_speech_timestamp,
        )
        # flatten the per-sample word lists into a single list for the batch
        outputs = [outputs[i][0]["words"] for i in range(len(outputs))]
        outputs = sum(outputs, [])
        return {
            "words": [outputs],
        }

    def postprocess(self, model_outputs):
        return model_outputs

    def get_iterator(
        self,
        inputs,
        num_workers: int,
        batch_size: int,
        preprocess_params: dict,
        forward_params: dict,
        postprocess_params: dict,
    ):
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
        if "TOKENIZERS_PARALLELISM" not in os.environ:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # TODO hack by collating feature_extractor and image_processor
        def stack(items):
            return {
                "inputs": torch.stack([x["inputs"] for x in items]),
                "start": [x["start"] for x in items],
                "end": [x["end"] for x in items],
                "segment_size": [x["segment_size"] for x in items],
            }

        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack
        )
        model_iterator = PipelineIterator(
            dataloader, self.forward, forward_params, loader_batch_size=batch_size
        )
        final_iterator = PipelineIterator(
            model_iterator, self.postprocess, postprocess_params
        )
        return final_iterator

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        batch_size: Optional[int] = None,
        num_workers=0,
        language: Optional[str] = None,
        task: Optional[str] = None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
        verbose=False,
    ) -> TranscriptionResult:
        self.last_speech_timestamp = 0.0
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {
                    "inputs": audio[f1:f2],
                    "start": seg["start"],
                    "end": seg["end"],
                    "segment_size": int(
                        ceil(seg["end"] - seg["start"]) * self.model.frames_per_second
                    ),
                }

        # Pre-process audio and merge chunks as defined by the respective VAD child class.
        # If vad_model was assigned manually (see `load_model`), fall back to the
        # Pyannote helpers.
        if issubclass(type(self.vad_model), Vad):
            waveform = self.vad_model.preprocess_audio(audio)
            merge_chunks = self.vad_model.merge_chunks
        else:
            waveform = Pyannote.preprocess_audio(audio)
            merge_chunks = Pyannote.merge_chunks
        pre_merge_vad_segments = self.vad_model(
            {"waveform": waveform, "sample_rate": SAMPLE_RATE}
        )
        vad_segments = merge_chunks(
            pre_merge_vad_segments,
            chunk_size,
            onset=self._vad_params["vad_onset"],
            offset=self._vad_params["vad_offset"],
        )

        if self.tokenizer is None:
            language = language or self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
            print("Suppressing numeral and symbol tokens")
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

        binarize = Binarize(
            max_duration=chunk_size,
            onset=self._vad_params["vad_onset"],
            offset=self._vad_params["vad_offset"],
        )
        segments = binarize(pre_merge_vad_segments).get_timeline()
        segments: List[SingleSegment] = [
            {
                "start": seg.start,
                "end": seg.end,
                "text": "",
            }
            for seg in segments
        ]

        batch_size = batch_size or self._batch_size
        total_segments = len(vad_segments)
        for idx, out in enumerate(
            self.__call__(
                data(audio, vad_segments),
                batch_size=batch_size,
                num_workers=num_workers,
            )
        ):
            if print_progress:
                base_progress = ((idx + 1) / total_segments) * 100
                percent_complete = (
                    base_progress / 2 if combined_progress else base_progress
                )
                print(f"Progress: {percent_complete:.2f}%...")
            last_speech_timestamp_index = 0
            next_last_speech_timestamp_index = 0
            for word in out["words"]:
                possible_segment_indices = []
                for i, segment in enumerate(segments[last_speech_timestamp_index:]):
                    if segment["end"] < word["start"]:
                        # this segment ends before the word starts, so later
                        # words can skip it too (convert to an absolute index)
                        next_last_speech_timestamp_index = (
                            last_speech_timestamp_index + i + 1
                        )
                    overlap_start = max(segment["start"], word["start"])
                    overlap_end = min(segment["end"], word["end"])
                    if overlap_start <= overlap_end:
                        possible_segment_indices.append(
                            last_speech_timestamp_index + i
                        )
                last_speech_timestamp_index = next_last_speech_timestamp_index
                if len(possible_segment_indices) == 0:
                    print(
                        f"Warning: Word '{word['word']}' at [{round(word['start'], 3)} --> {round(word['end'], 3)}] is not in any segment."
                    )
                else:
                    # assign the word to the segment it overlaps the most
                    largest_overlap = -1
                    best_segment_index = None
                    for i in possible_segment_indices:
                        segment = segments[i]
                        overlap_start = max(segment["start"], word["start"])
                        overlap_end = min(segment["end"], word["end"])
                        overlap_duration = overlap_end - overlap_start
                        if overlap_duration > largest_overlap:
                            largest_overlap = overlap_duration
                            best_segment_index = i
                    segments[best_segment_index]["text"] += word["word"]

        # revert the tokenizer if multilingual inference is enabled
        if self.preset_language is None:
            self.tokenizer = None
        # revert suppressed tokens if suppress_numerals is enabled
        if self.suppress_numerals:
            self.options = replace(
                self.options, suppress_tokens=previous_suppress_tokens
            )

        return {"segments": segments, "language": language}

    def detect_language(self, audio: np.ndarray) -> str:
        if audio.shape[0] < N_SAMPLES:
            print(
                "Warning: audio is shorter than 30s, language detection may be inaccurate."
            )
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        segment = log_mel_spectrogram(
            audio[:N_SAMPLES],
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0],
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        language = language_token[2:-2]
        print(
            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
        )
        return language


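# Illustrative direct use of language detection (a sketch; the model size,
# device, and "clip.wav" path are placeholders). When no language is preset,
# `transcribe` calls `detect_language` on the first 30s automatically, but it
# can also be invoked on its own via a pipeline built with `load_model` below:
#
#   pipe = load_model("small", device="cpu", compute_type="int8")
#   lang = pipe.detect_language(load_audio("clip.wav"))
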
def load_model(
    whisper_arch: str,
    device: str,
    device_index=0,
    compute_type="float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model: Optional[Vad] = None,
    vad_method: Optional[str] = "pyannote",
    vad_options: Optional[dict] = None,
    model: Optional[WhisperModel] = None,
    task="transcribe",
    download_root: Optional[str] = None,
    local_files_only=False,
    threads=4,
) -> FasterWhisperPipeline:
    """Load a Whisper model for inference.

    Args:
        whisper_arch - The name of the Whisper model to load.
        device - The device to load the model on.
        compute_type - The compute type to use for the model.
        vad_method - The VAD method to use. vad_model takes priority if it is not None.
        asr_options - A dictionary of options to use for the model.
        language - The language of the model. (use English for now)
        model - The WhisperModel instance to use.
        download_root - The root directory to download the model to.
        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        threads - The number of CPU threads to use per worker; this is multiplied by the number of workers.
    Returns:
        A Whisper pipeline.
    """

    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(
        whisper_arch,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        download_root=download_root,
        local_files_only=local_files_only,
        cpu_threads=threads,
    )
    if language is not None:
        tokenizer = Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will first be detected for each audio file (increases inference time)."
        )
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "multilingual": model.model.is_multilingual,
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
        "hotwords": None,
    }
    if asr_options is not None:
        default_asr_options.update(asr_options)

    suppress_numerals = default_asr_options["suppress_numerals"]
    del default_asr_options["suppress_numerals"]

    default_asr_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363,
    }
    if vad_options is not None:
        default_vad_options.update(vad_options)

    # Note: a manually assigned vad_model takes priority over vad_method!
    if vad_model is not None:
        print("Using manually assigned vad_model; vad_method is ignored.")
    else:
        if vad_method == "silero":
            vad_model = Silero(**default_vad_options)
        elif vad_method == "pyannote":
            vad_model = Pyannote(
                torch.device(device), use_auth_token=None, **default_vad_options
            )
        else:
            raise ValueError(f"Invalid vad_method: {vad_method}")

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
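

# End-to-end usage (a minimal sketch, not shipped with this module; "audio.wav",
# the model size, and the CUDA device are assumptions):
#
#   if __name__ == "__main__":
#       pipe = load_model(
#           "large-v3",
#           device="cuda",
#           compute_type="float16",
#           vad_method="pyannote",
#           asr_options={"suppress_numerals": True},
#       )
#       result = pipe.transcribe("audio.wav", batch_size=16, print_progress=True)
#       print(result["language"])
#       for seg in result["segments"]:
#           print(f"[{seg['start']:.2f} --> {seg['end']:.2f}] {seg['text']}")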