import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np

# [Your existing code for device setup, config, SpeechInferenceDataset, SpeechClassifier]

# Prepare data function
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    # Fixed two-token decoder prompt; the Whisper decoder requires decoder_input_ids
    decoder_input_ids = torch.tensor([[1, 1]])
    return input_features.to(device), decoder_input_ids.to(device)

# Prediction function
def predict(audio_data, sampling_rate, config):
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = SpeechClassifier(config).to(device)
    # torch.hub caches the downloaded weights after the first call
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
        map_location=device
    ))
    model.eval()
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
        predicted_ids = int(torch.argmax(logits, dim=-1))
    return predicted_ids

# Unified Gradio interface function
def gradio_interface(audio_input):
    if audio_input is None:
        return "No audio provided"
    # With type="numpy", Gradio delivers both uploads and microphone recordings
    # as a (sample_rate, data) tuple
    sample_rate, audio_data = audio_input
    # Scale 16-bit PCM to float32 in [-1, 1], which is what the Whisper
    # feature extractor expects, then collapse stereo to mono
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    else:
        audio_data = audio_data.astype(np.float32)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    # Whisper expects 16 kHz audio; do a simple linear resample if the browser
    # recorded or the upload was saved at a different rate
    if sample_rate != 16000:
        target_length = int(round(len(audio_data) * 16000 / sample_rate))
        audio_data = np.interp(
            np.linspace(0.0, len(audio_data), num=target_length, endpoint=False),
            np.arange(len(audio_data)),
            audio_data,
        ).astype(np.float32)
        sample_rate = 16000
    prediction = predict(audio_data, sample_rate, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label

# Create Gradio interface
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
    outputs=gr.Textbox(label="Prediction")
)

# Launch the demo
demo.launch(debug=True)
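
For reference, below is a minimal sketch of what the pieces elided by the placeholder comment above (device setup, config, and SpeechClassifier) might look like. It assumes a Whisper backbone with mean pooling and a linear classification head, plus a hypothetical "num_labels" config key alongside the "encoder" key read by predict(); your actual definitions may differ and should stay in place of the placeholder.

import torch
import torch.nn as nn
from transformers import WhisperModel
from huggingface_hub import PyTorchModelHubMixin

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical config: "encoder" is the key used by predict() above;
# "num_labels" is an assumed key sizing the classification head.
config = {
    "encoder": "openai/whisper-base",
    "num_labels": 2,  # hypernasality vs. no hypernasality
}

class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Assumed architecture: Whisper backbone + mean pooling + linear head."""

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Linear(self.encoder.config.d_model, config["num_labels"])

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        pooled = outputs.last_hidden_state.mean(dim=1)  # pool decoder hidden states
        return self.classifier(pooled)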