# ASR / app.py
# Hugging Face Space by NightPrince — commit 0bbfec6 (verified): "Update app.py".
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import numpy as np
import librosa
# Pretrained checkpoint shared by the processor (audio feature extraction and
# tokenization) and the CTC acoustic model; both are fetched once at startup.
# NOTE(review): the checkpoint id suggests an English->Arabic speech
# *translation* model; confirm it really ships a CTC head and a Wav2Vec2
# tokenizer, otherwise Wav2Vec2ForCTC may load with an untrained output layer.
model_name = "facebook/s2t-wav2vec2-large-en-ar"
model, processor = (
    Wav2Vec2ForCTC.from_pretrained(model_name),
    Wav2Vec2Processor.from_pretrained(model_name),
)
# Function to transcribe audio using the model
def transcribe(audio):
# Extract audio data from the tuple (audio, sample_rate)
audio_data, sample_rate = audio
# Resample the audio to 16kHz if necessary
if audio_data.ndim > 1: # If audio is stereo
audio_data = audio_data.mean(axis=1) # Convert to mono
# Ensure the audio is resampled to 16kHz if it's not already
if sample_rate != 16000:
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
# Process the audio to match the model's input format
inputs = processor(audio_data, return_tensors="pt", sampling_rate=16000)
# Get the model's predictions
with torch.no_grad():
logits = model(input_values=inputs.input_values).logits
# Decode the predicted text
predicted_ids = logits.argmax(dim=-1)
transcription = processor.decode(predicted_ids[0])
return transcription
# Wire the transcriber into a minimal Gradio UI: one audio widget whose value
# is handed to `transcribe` as a numpy array, and a plain-text output box.
# `live` is not set, so transcription runs when the user submits a clip.
audio_input = gr.Audio(type="numpy")
interface = gr.Interface(fn=transcribe, inputs=audio_input, outputs="text")

# Start the Gradio web server as soon as this script runs.
interface.launch()