# Hugging Face Space: hypernasality-detection demo (scraped page header removed).
# Third-party imports, grouped and alphabetized (trailing scrape artifacts removed).
import numpy as np

import datasets
import gradio as gr
import torch
import torch.nn as nn
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor

# [Your existing code for device setup, config, SpeechInferenceDataset, SpeechClassifier]
# Prepare data function
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Convert raw audio samples into Whisper model inputs.

    Args:
        audio_data: 1-D array of audio samples.
        sampling_rate: sample rate of ``audio_data`` in Hz.
        model_checkpoint: HF checkpoint whose feature extractor is used.

    Returns:
        Tuple ``(input_features, decoder_input_ids)``, both moved to the
        module-level ``device`` (defined elsewhere in this file).
    """
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    # Fixed two-token decoder prompt fed to the encoder-decoder classifier.
    # NOTE(review): token id 1 is unusual for Whisper decoders — confirm this
    # matches what the checkpoint was trained with.
    decoder_input_ids = torch.tensor([[1, 1]])
    return input_features.to(device), decoder_input_ids.to(device)
# Prediction function
def predict(audio_data, sampling_rate, config):
    """Run the hypernasality classifier on one audio clip.

    Args:
        audio_data: 1-D array of audio samples.
        sampling_rate: sample rate of ``audio_data`` in Hz.
        config: model config dict; ``config["encoder"]`` names the Whisper
            checkpoint used for feature extraction.

    Returns:
        Predicted class index as an ``int`` (1 means hypernasality detected).
    """
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    # Build and load the model once, then reuse it across calls: the original
    # code re-downloaded the checkpoint on every single prediction.
    if not hasattr(predict, "_model"):
        model = SpeechClassifier(config).to(device)
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
            map_location=device,
        ))
        model.eval()
        predict._model = model
    with torch.no_grad():
        logits = predict._model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))
# Unified Gradio interface function
def gradio_interface(audio_input):
    """Gradio callback: map an audio input to a human-readable label.

    Args:
        audio_input: either a ``(sample_rate, samples)`` tuple from
            ``gr.Audio(type="numpy")`` or a filesystem path to a raw
            16-bit PCM file.

    Returns:
        Label string describing the prediction.
    """
    if isinstance(audio_input, tuple):
        # gr.Audio(type="numpy") yields (sample_rate, samples) — the original
        # unpacked this in the wrong order, feeding the sample rate as audio.
        sample_rate, audio_data = audio_input
    else:
        # Otherwise, it's an uploaded file path (raw 16-bit PCM assumed).
        with open(audio_input, "rb") as f:
            audio_data = np.frombuffer(f.read(), np.int16)
        sample_rate = 16000  # Assume 16kHz sample rate for uploaded files
    # Whisper's feature extractor expects float samples in [-1, 1]; Gradio and
    # the raw-PCM path both deliver int16, so normalize before prediction.
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data = audio_data.astype(np.float32) / 32768.0
    prediction = predict(audio_data, sample_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
# Create Gradio interface: single audio input (upload or mic) -> text label.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
    outputs=gr.Textbox(label="Prediction"),
)

# Launch the demo (debug=True surfaces server-side tracebacks in the logs).
demo.launch(debug=True)