jcho02's picture
Update app.py
a6d5ae5 verified
raw
history blame
2.25 kB
import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
# NOTE: this snippet references names that must be defined above this point:
# `device` (torch device), `config` (model config dict with an "encoder" key),
# and the `SpeechClassifier` model class (plus SpeechInferenceDataset if used).
# Cache one feature extractor per checkpoint so repeated predictions do not
# rebuild (and potentially re-download) the extractor on every call.
_FEATURE_EXTRACTORS = {}


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Convert raw audio into Whisper encoder inputs.

    Args:
        audio_data: 1-D waveform (float values expected by the Whisper
            feature extractor; see its documentation for normalization).
        sampling_rate: Sample rate of ``audio_data`` in Hz.
        model_checkpoint: Hub id of the Whisper checkpoint whose feature
            extractor should be used.

    Returns:
        Tuple of ``(input_features, decoder_input_ids)``, both moved to the
        module-level ``device``.
    """
    extractor = _FEATURE_EXTRACTORS.get(model_checkpoint)
    if extractor is None:
        extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
        _FEATURE_EXTRACTORS[model_checkpoint] = extractor
    inputs = extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    # NOTE(review): [[1, 1]] is an unusual decoder prompt — Whisper's decoder
    # start token is normally 50258. Kept as-is because the classifier was
    # presumably trained with this value; confirm against training code.
    decoder_input_ids = torch.tensor([[1, 1]])
    return input_features.to(device), decoder_input_ids.to(device)
# Prediction function
# Cache loaded classifiers by encoder checkpoint so the ~100MB+ state dict is
# not re-fetched from the Hub on every single prediction request.
_MODEL_CACHE = {}


def predict(audio_data, sampling_rate, config):
    """Classify an audio clip with the hypernasality SpeechClassifier.

    Args:
        audio_data: 1-D waveform array.
        sampling_rate: Sample rate of ``audio_data`` in Hz.
        config: Model config dict; ``config["encoder"]`` names the Whisper
            checkpoint used for feature extraction and model construction.

    Returns:
        int: Index of the highest-scoring class (1 = hypernasality detected,
        per the caller's interpretation).
    """
    input_features, decoder_input_ids = prepare_data(
        audio_data, sampling_rate, config["encoder"]
    )
    cache_key = config["encoder"]
    model = _MODEL_CACHE.get(cache_key)
    if model is None:
        model = SpeechClassifier(config).to(device)
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
                map_location=device,
            )
        )
        model.eval()
        _MODEL_CACHE[cache_key] = model
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))
# Unified Gradio interface function
def gradio_interface(audio_input):
    """Run hypernasality detection on audio from the Gradio widget.

    Args:
        audio_input: Either a ``(sample_rate, numpy_array)`` tuple, which is
            what ``gr.Audio(type="numpy")`` delivers for both microphone and
            uploaded audio, or a filepath string (defensive fallback).

    Returns:
        str: Human-readable prediction label.
    """
    if isinstance(audio_input, tuple):
        # BUGFIX: gr.Audio(type="numpy") yields (sample_rate, data) — the
        # original code unpacked the pair in the opposite order.
        sample_rate, audio_data = audio_input
    else:
        # Fallback: raw file path. Assumes headerless 16-bit PCM @ 16 kHz;
        # NOTE(review): for WAV files this also reads the 44-byte header as
        # samples — confirm the expected upload format.
        with open(audio_input, "rb") as f:
            audio_data = np.frombuffer(f.read(), np.int16)
        sample_rate = 16000  # Assume 16kHz sample rate for uploaded files
    audio_data = np.asarray(audio_data)
    if audio_data.dtype == np.int16:
        # Whisper's feature extractor expects float waveform in [-1, 1];
        # gradio delivers raw int16 PCM, so normalize explicitly.
        audio_data = audio_data.astype(np.float32) / 32768.0
    if audio_data.ndim > 1:
        # Down-mix multi-channel (samples, channels) recordings to mono.
        audio_data = audio_data.mean(axis=-1)
    prediction = predict(audio_data, sample_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
# Create Gradio interface
# Assemble the Gradio UI: a single audio input feeding the classifier,
# with the textual verdict as the only output.
audio_in = gr.Audio(type="numpy", label="Upload or Record Audio")
text_out = gr.Textbox(label="Prediction")
demo = gr.Interface(fn=gradio_interface, inputs=audio_in, outputs=text_out)

# Start the app with verbose error reporting enabled.
demo.launch(debug=True)