import tempfile

import gradio as gr
import torch
import torchaudio
from pydub import AudioSegment, silence
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# ---------------- Config ---------------- #
MODEL_NAME = "mrmuminov/whisper-small-uz"
SAMPLE_RATE = 16000
MIN_LEN_MS = 15000
MAX_LEN_MS = 25000
SILENCE_THRESH = -40  # in dBFS
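# MAX_LEN_MS stays under 30 s because Whisper pads or truncates every input to
# a fixed 30 s log-mel window, so anything longer would lose its tail;
# MIN_LEN_MS prevents the splitter from cutting on every brief pause.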
# ---------------- Load Model ---------------- #
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()  # inference only, no gradient tracking
# Force Uzbek transcription at decode time; the language cannot be set through
# the feature extractor, so it is pinned here and passed to generate() below.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")
# ---------------- Chunking Logic ---------------- #
def split_on_silence_with_duration_control(audio, min_len, max_len, silence_thresh):
    """Split audio into chunks of min_len..max_len ms, cutting at silences when possible."""
    silences = silence.detect_silence(audio, min_silence_len=500, silence_thresh=silence_thresh)
    # The midpoint of each detected silence is a candidate cut point
    silence_midpoints = [(start + end) // 2 for start, end in silences]
    chunks = []
    start = 0
    duration = len(audio)
    while start < duration:
        end = min(start + max_len, duration)
        # Prefer the last silence midpoint inside the allowed window;
        # otherwise make a hard cut at max_len
        valid_splits = [s for s in silence_midpoints if start + min_len <= s <= end]
        split_point = valid_splits[-1] if valid_splits else end
        chunk = audio[start:split_point]
        if len(chunk) > 0:  # avoid zero-length chunks
            chunks.append(chunk)
        start = split_point
    return chunks
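# Quick sanity check for the splitter (the file name below is hypothetical):
#   seg = AudioSegment.from_file("sample.wav").set_channels(1).set_frame_rate(SAMPLE_RATE)
#   parts = split_on_silence_with_duration_control(seg, MIN_LEN_MS, MAX_LEN_MS, SILENCE_THRESH)
#   assert all(len(p) <= MAX_LEN_MS for p in parts)  # no chunk exceeds the window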
# ---------------- Transcription ---------------- #
def transcribe(audio_file_path):
    audio = AudioSegment.from_file(audio_file_path)
    # Whisper expects mono audio at 16 kHz
    audio = audio.set_channels(1).set_frame_rate(SAMPLE_RATE)
    chunks = split_on_silence_with_duration_control(
        audio, min_len=MIN_LEN_MS, max_len=MAX_LEN_MS, silence_thresh=SILENCE_THRESH
    )
    results = []
    for chunk in chunks:
        # Round-trip through a temporary WAV file to hand the pydub chunk to torchaudio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmpfile:
            chunk.export(tmpfile.name, format="wav")
            waveform, _ = torchaudio.load(tmpfile.name)
        # The feature extractor ignores a language argument, so the language is
        # forced via forced_decoder_ids in generate() instead
        input_features = processor(
            waveform.squeeze().numpy(),
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        results.append(text)
    return " ".join(results)
# ---------------- Gradio UI ---------------- #
# Build the Interface outside the Blocks context so it is rendered exactly
# once, by the TabbedInterface inside the Blocks
file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=gr.Textbox(label="Transcription"),
)

with gr.Blocks() as demo:
    gr.Markdown(f"### {MODEL_NAME}: Transcribe Uzbek Audio")
    gr.TabbedInterface([file_transcribe], ["Audio File"])

demo.launch()