# geo-whisper / app.py
# GiorgiSekhniashvili's picture
# use any backend
# ebbe76c
from pathlib import Path
import gradio as gr
import torch
import torchaudio
from transformers import (
WhisperFeatureExtractor,
WhisperForConditionalGeneration,
WhisperTokenizerFast,
)
# Use the GPU in half precision when CUDA is available; otherwise run on the
# CPU in full precision.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    DTYPE = torch.float16
else:
    DEVICE = torch.device("cpu")
    DTYPE = torch.float32

# Sample rate that audio is resampled to and that the feature extractor is
# told to expect.
SAMPLING_RATE = 16_000

# The feature extractor and tokenizer are loaded from the base checkpoint,
# while the model weights come from the fine-tuned checkpoint below.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")

# Decoder prompt that pins generation to Georgian transcription.
forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
    task="transcribe", language="georgian"
)

model = WhisperForConditionalGeneration.from_pretrained(
    "GiorgiSekhniashvili/whisper-tiny-ka-09", torch_dtype=DTYPE
)
model.generation_config.forced_decoder_ids = forced_decoder_ids
model.to(DEVICE)
def load_audio(audio_path: Path, target_sr: int):
    """Load an audio file as a single-channel waveform sampled at ``target_sr``.

    Args:
        audio_path: Location of the audio file; passed straight to
            ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        The waveform tensor. When the loaded tensor has more than one entry
        along dim 0, those entries are averaged into one (dim 0 is kept with
        size 1). The result is resampled only if the file's rate differs
        from ``target_sr``.
    """
    samples, source_sr = torchaudio.load(audio_path)

    # Downmix: average over dim 0 (assumed to be the channel axis, per
    # torchaudio's default channels-first layout) but keep that axis.
    has_multiple_channels = samples.shape[0] > 1
    if has_multiple_channels:
        samples = samples.mean(dim=0, keepdim=True)

    # Nothing more to do when the file is already at the requested rate.
    if source_sr == target_sr:
        return samples

    return torchaudio.functional.resample(
        samples, orig_freq=source_sr, new_freq=target_sr
    )
def transcribe(audio):
    """Transcribe an audio file to text with the fine-tuned Whisper model.

    Args:
        audio: Path of the audio file to transcribe, as handed over by the
            Gradio ``Audio`` input (``type="filepath"``). ``None`` when the
            user submits without recording or uploading anything.

    Returns:
        The decoded transcription of the first (only) generated sequence,
        with Whisper's special tokens removed. If no audio was given or the
        file could not be loaded, a message describing the problem is
        returned instead of raising, so it shows up in the output text box.
    """
    # Gradio passes None when the form is submitted empty; answer with a
    # clear message instead of surfacing an opaque loader exception.
    if audio is None:
        return "No audio provided. Please record or upload an audio clip."

    try:
        waveform = load_audio(audio, target_sr=SAMPLING_RATE)
    except Exception as e:
        # Top-level UI boundary: report the load failure as the output text.
        return str(e)

    # load_audio returns a waveform with a leading size-1 dim; index 0 picks
    # the single remaining channel for the feature extractor.
    input_values = feature_extractor(
        waveform[0], sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    # Match the device and dtype the model was loaded with.
    input_features = input_values.input_features.to(DEVICE, dtype=DTYPE)

    with torch.no_grad():
        outputs = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_new_tokens=444,
        )

    # Drop control tokens (start-of-transcript, language/task markers,
    # end-of-text) so only the spoken text reaches the user.
    transcriptions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return transcriptions[0]
# Audio input that accepts either a microphone recording or an uploaded file
# and hands the handler a path to the audio on disk.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")

# Single-function web UI: audio in, transcription text out.
iface = gr.Interface(
    title="Whisper Geo",
    description="Realtime demo for Georgian speech recognition using a fine-tuned Whisper model.",
    fn=transcribe,
    inputs=_audio_input,
    outputs="text",
)

# Start the web server as soon as the module is executed.
iface.launch()