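"""Gradio demo that flags hypernasality in recorded speech.

A pretrained Whisper model turns the uploaded audio into hidden states, and a
small MLP head on top of the first decoder position classifies the clip as
hypernasal or not.
"""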
import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import Audio
from huggingface_hub import PyTorchModelHubMixin

# Run inference on a GPU when one is available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SpeechInferenceDataset(Dataset):
    """Wraps raw audio examples and turns them into Whisper model inputs."""

    def __init__(self, audio_data, feature_extractor, model_config):
        self.audio_data = audio_data
        self.feature_extractor = feature_extractor
        self.model_config = model_config

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        # Convert the raw waveform into log-mel input features.
        inputs = self.feature_extractor(
            self.audio_data[index]["audio"]["array"],
            return_tensors="pt",
            sampling_rate=self.audio_data[index]["audio"]["sampling_rate"],
        )
        input_features = inputs.input_features

        # Whisper is an encoder-decoder model, so its forward pass also needs
        # decoder inputs; feed decoder-start tokens taken from the model config.
        decoder_input_ids = torch.tensor([[1, 1]]) * self.model_config.decoder_start_token_id
        return input_features, decoder_input_ids


class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper backbone with an MLP classification head."""

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        # MLP head that maps the pooled Whisper representation to class logits.
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"]),
        )

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        # Pool by taking the hidden state at the first decoder position.
        pooled_output = outputs["last_hidden_state"][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits


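# Because SpeechClassifier mixes in PyTorchModelHubMixin, trained weights can be
# shared and restored through the Hub. A minimal sketch, assuming a hypothetical
# repo id "your-username/whisper-hypernasality" and an already-trained model:
#
#   config = {"encoder": "openai/whisper-base", "num_labels": 2}
#   model = SpeechClassifier(config)
#   ...train...
#   model.push_to_hub("your-username/whisper-hypernasality", config=config)
#   model = SpeechClassifier.from_pretrained("your-username/whisper-hypernasality")

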
def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base"):
    """Load one audio file and return (input_features, decoder_input_ids) tensors."""
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    model_config = AutoConfig.from_pretrained(model_checkpoint)
    # Build a one-example dataset and resample the audio to Whisper's 16 kHz.
    inference_data = datasets.Dataset.from_dict(
        {"path": [audio_file_path], "audio": [audio_file_path]}
    ).cast_column("audio", Audio(sampling_rate=16_000))
    inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor, model_config)
    inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
    input_features, decoder_input_ids = next(iter(inference_loader))

    # The DataLoader adds a batch dimension on top of the one the feature
    # extractor already returned, so drop the extra one before moving to device.
    input_features = input_features.squeeze(1).to(device)
    decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
    return input_features, decoder_input_ids


def predict(audio_file_path, config=None):
    """Return the predicted class id (0 or 1) for a single audio file."""
    if config is None:
        config = {"encoder": "openai/whisper-base", "num_labels": 2}
    input_features, decoder_input_ids = prepare_data(audio_file_path)
    # NOTE: this builds a model with a randomly initialized classification head;
    # in practice, load trained weights, e.g. via SpeechClassifier.from_pretrained.
    model = SpeechClassifier(config).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
        predicted_id = int(torch.argmax(logits, dim=-1))
    return predicted_id


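# Quick smoke test outside of Gradio, assuming a hypothetical local 16 kHz
# recording "sample.wav":
#
#   print(predict("sample.wav"))  # -> 0 or 1

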
def gradio_interface(uploaded_file):
    # gr.File hands the callback a path to the uploaded file on disk, so it can
    # be passed straight to predict() without rewriting the bytes.
    prediction = predict(uploaded_file)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.File(label="Upload Audio File"),
    outputs="text",
)
iface.launch()