# Spaces: Sleeping
import logging
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
)
# Runtime configuration: use the GPU with half precision when CUDA is
# available, otherwise fall back to the CPU with full precision.
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _HAS_CUDA else torch.device("cpu")
DTYPE = torch.float16 if _HAS_CUDA else torch.float32
SAMPLING_RATE = 16_000

# The weights come from the fine-tuned checkpoint, while the feature
# extractor and tokenizer are loaded from the base Whisper repository.
_FINETUNED_CHECKPOINT = "GiorgiSekhniashvili/whisper-tiny-ka-09"
_BASE_CHECKPOINT = "openai/whisper-tiny"

model = WhisperForConditionalGeneration.from_pretrained(
    _FINETUNED_CHECKPOINT,
    torch_dtype=DTYPE,
)
feature_extractor = WhisperFeatureExtractor.from_pretrained(_BASE_CHECKPOINT)
tokenizer = WhisperTokenizerFast.from_pretrained(_BASE_CHECKPOINT)

# Decoder prompt for Georgian transcription, stored on the model's
# generation config before the model is moved to the selected device.
forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
    task="transcribe",
    language="georgian",
)
model.generation_config.forced_decoder_ids = forced_decoder_ids
model.to(DEVICE)
def load_audio(audio_path: Path, target_sr: int):
    """Load an audio file as a single-channel waveform sampled at ``target_sr``.

    Args:
        audio_path: Location of the audio file handed to ``torchaudio.load``.
        target_sr: Sampling rate, in Hz, the returned waveform should have.

    Returns:
        The waveform tensor. When dimension 0 (treated here as the channel
        axis) holds more than one entry, the entries are averaged into one
        while keeping that dimension.
    """
    samples, native_sr = torchaudio.load(audio_path)

    # Downmix to mono by averaging across dimension 0.
    channel_count = samples.shape[0]
    if channel_count > 1:
        samples = samples.mean(dim=0, keepdim=True)

    # Nothing more to do when the file already uses the requested rate.
    if native_sr == target_sr:
        return samples

    return torchaudio.functional.resample(
        samples,
        orig_freq=native_sr,
        new_freq=target_sr,
    )
def transcribe(audio):
    """Transcribe an audio file to Georgian text with the fine-tuned model.

    Args:
        audio: Path of the audio file to transcribe, or ``None``/empty when
            nothing was recorded or uploaded.

    Returns:
        The decoded transcription with Whisper special tokens removed. When
        no audio was supplied, or the file cannot be loaded, a readable
        message is returned instead so the UI shows it in the output box.
    """
    # Guard the "submit without audio" case explicitly instead of letting the
    # loader fail on a missing path and surface an opaque library error.
    if not audio:
        return "No audio provided. Please record or upload a clip first."

    try:
        waveform = load_audio(audio, target_sr=SAMPLING_RATE)
    except Exception as e:
        # UI boundary: keep showing the error text to the user, but also log
        # the traceback so the failure is not silently swallowed.
        logging.getLogger(__name__).exception("Failed to load audio %r", audio)
        return str(e)

    # NOTE(review): a single feature window is built from the whole clip;
    # long recordings may be truncated by the feature extractor - confirm.
    input_values = feature_extractor(
        waveform[0], sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    input_features = input_values.input_features.to(DEVICE, dtype=DTYPE)

    with torch.no_grad():
        outputs = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            # NOTE(review): presumably sized to leave room for the forced
            # prompt tokens within the decoder's length limit - confirm.
            max_new_tokens=444,
        )

    # Drop control tokens (start-of-transcript, language, task, end-of-text)
    # so only the spoken text reaches the user.
    transcriptions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return transcriptions[0]
# Gradio front end: a single audio input (microphone or upload, handed to
# `transcribe` as a file path) and a plain-text output.
audio_input = gr.Audio(type="filepath", sources=["microphone", "upload"])

iface = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs="text",
    title="Whisper Geo",
    description="Realtime demo for Georgian speech recognition using a fine-tuned Whisper model.",
)

iface.launch()