import os
import time
import numpy as np
from typing import Union, Tuple, List
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import gradio as gr
from huggingface_hub import hf_hub_download
import whisper
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn

from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase

class InsanelyFastWhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        openai_models = whisper.available_models()
        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
        self.available_models = openai_models + distil_models
        self.available_compute_types = ["float16"]
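        # For reference: whisper.available_models() returns the OpenAI checkpoint
        # names (e.g. "tiny", "base.en", ..., "large-v3"), so available_models is
        # that list plus the four distil-whisper variants above.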

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for insanely-fast-whisper.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio path, audio numpy array, or audio torch Tensor
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperParameters.as_value(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Transcribing... Progress is not shown in insanely-fast-whisper.")
        with Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(style="yellow1", pulse_style="white"),
            TimeElapsedColumn(),
        ) as rich_progress:  # renamed from "progress" to avoid shadowing the gradio progress argument
            rich_progress.add_task("[yellow]Transcribing...", total=None)

            kwargs = {
                "no_speech_threshold": params.no_speech_threshold,
                "temperature": params.temperature,
                "compression_ratio_threshold": params.compression_ratio_threshold,
                "logprob_threshold": params.log_prob_threshold,
            }

            # English-only checkpoints (".en") accept neither a language nor a task option.
            if not self.current_model_size.endswith(".en"):
                kwargs["language"] = params.lang
                kwargs["task"] = "translate" if params.is_translate else "transcribe"

            segments = self.model(
                inputs=audio,
                return_timestamps=True,
                chunk_length_s=params.chunk_length,
                batch_size=params.batch_size,
                generate_kwargs=kwargs
            )

        segments_result = self.format_result(
            transcribed_result=segments,
        )
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time
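
    # Usage sketch (illustrative, not part of the module): the Gradio UI passes the
    # flattened WhisperParameters tuple, so a call looks roughly like
    #   segments, elapsed = inference.transcribe("audio.wav", gr.Progress(), *params_tuple)
    # and returns e.g. [{"start": 0.0, "end": 3.2, "text": " Hello"}, ...] plus the
    # elapsed wall-clock seconds.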

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress(),
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription, passed to the transformers pipeline as
            torch_dtype (this implementation only exposes "float16").
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model...")

        model_path = os.path.join(self.model_dir, model_size)
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = pipeline(
            "automatic-speech-recognition",
            model=os.path.join(self.model_dir, model_size),
            torch_dtype=self.current_compute_type,
            device=self.device,
            model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
        )
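
    # Note: the pipeline above prefers Flash Attention 2 when the optional flash-attn
    # package is installed and otherwise falls back to PyTorch's scaled dot product
    # attention; both are valid values for transformers' attn_implementation kwarg.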

    @staticmethod
    def format_result(
        transcribed_result: dict
    ) -> List[dict]:
        """
        Format the transcription result of insanely_fast_whisper to match the other implementations.

        Parameters
        ----------
        transcribed_result: dict
            Transcription result of the insanely_fast_whisper

        Returns
        ----------
        result: List[dict]
            Formatted result, matching the other implementations
        """
        result = transcribed_result["chunks"]
        for item in result:
            start, end = item["timestamp"][0], item["timestamp"][1]
            if end is None:
                # The pipeline can return a None end timestamp (e.g. on the final
                # chunk); fall back to the start timestamp.
                end = start
            item["start"] = start
            item["end"] = end
        return result
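
    # Example shape (illustrative): {"chunks": [{"timestamp": (0.0, 3.2), "text": " Hello"}]}
    # becomes [{"timestamp": (0.0, 3.2), "text": " Hello", "start": 0.0, "end": 3.2}].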

    @staticmethod
    def download_model(
        model_size: str,
        download_root: str,
        progress: gr.Progress
    ):
        progress(0, 'Downloading model...')
        print(f'Downloading {model_size} to "{download_root}"...')

        os.makedirs(download_root, exist_ok=True)
        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]

        if model_size.startswith("distil"):
            repo_id = f"distil-whisper/{model_size}"
        else:
            repo_id = f"openai/whisper-{model_size}"
        for item in download_list:
            hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
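

# Minimal usage sketch (assumptions: run from the repo root so the `modules` package
# resolves, a GPU with float16 support is available, and `params_tuple` is the
# flattened WhisperParameters tuple normally supplied by the Gradio UI; "sample.wav"
# is a placeholder path):
#
#     inference = InsanelyFastWhisperInference()
#     segments, elapsed = inference.transcribe("sample.wav", gr.Progress(), *params_tuple)
#     for seg in segments:
#         print(f"[{seg['start']:.2f} -> {seg['end']:.2f}]{seg['text']}")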