|
import torchaudio |
|
import torch |
|
from model import M11 |
|
import gradio as gr |
|
|
|
def _cut_if_necessary(signal): |
|
if signal.shape[1] > 400000: |
|
signal = signal[:, :400000] |
|
|
|
return signal |
|
|
|
def _right_pad_if_necessary(signal): |
|
signal_length = signal.shape[1] |
|
if signal_length < 400000: |
|
num_missing_samples = 400000 - signal_length |
|
last_dim_padding = (0, num_missing_samples) |
|
signal = torch.nn.functional.pad(signal, last_dim_padding) |
|
|
|
return signal |
|
|
|
def preprocess(signal, sr, device):
    """Bring a loaded waveform into the fixed format fed to the classifier.

    Steps, in order: add a leading dim to 1-D input, resample to 8 kHz when
    ``sr`` differs, average multiple rows (dim 0) into one, then clip and
    right-pad via ``_cut_if_necessary`` / ``_right_pad_if_necessary``.

    Args:
        signal: Waveform tensor, either 1-D or 2-D (assumed to be
            ``(channels, samples)`` as returned by ``torchaudio.load`` --
            TODO confirm for other callers).
        sr: Sample rate of ``signal``.
        device: Device the resampling transform is moved to; ``signal`` is
            expected to already live there.

    Returns:
        The processed tensor with a single row along dim 0.
    """
    # A bare 1-D waveform gets a leading dim so the steps below can treat
    # every input as 2-D.
    if signal.dim() == 1:
        signal = signal.unsqueeze(0)

    # Resample only when the source rate is not already 8 kHz.
    if sr != 8_000:
        to_8k = torchaudio.transforms.Resample(sr, 8_000).to(device)
        signal = to_8k(signal)

    # Collapse multiple rows to one by averaging over dim 0. This stays after
    # resampling, matching the original order of operations.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Clip first, then pad, so the length along dim 1 ends up fixed.
    return _right_pad_if_necessary(_cut_if_necessary(signal))
|
|
|
|
|
def pipeline(audio_file):
    """Classify an uploaded audio file and return one score per label.

    Args:
        audio_file: Object whose ``.name`` attribute is the path of the audio
            file on disk (presumably the temp-file wrapper Gradio passes for
            ``type="file"`` inputs -- confirm against the installed version).

    Returns:
        dict mapping each entry of the module-level ``labels`` list to a
        float: the exponentiated model output at the same index.
    """
    audio_path = audio_file.name
    audio, sample_rate = torchaudio.load(audio_path)

    processed_audio = preprocess(audio.to(DEVICE), sample_rate, DEVICE)

    # Run the model exactly once, with autograd disabled. The previous code
    # ran a second forward pass outside ``no_grad`` only to feed a debug
    # print, doubling inference cost and building an unused autograd graph.
    with torch.no_grad():
        # unsqueeze(0) adds a leading batch dim; squeeze() drops all
        # singleton dims from the output again.
        raw_output = classifier(processed_audio.unsqueeze(0)).squeeze()
        # Presumably the model emits log-probabilities (e.g. log_softmax),
        # which is why they are exponentiated -- verify against M11.
        pred = torch.exp(raw_output)

    # Pair labels with scores by position; built once and reused for both the
    # debug output and the return value.
    scores = {label: float(score) for label, score in zip(labels, pred)}

    # Debug output kept from the original: final scores and raw model output.
    print(scores)
    print(raw_output)

    return scores
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint file to restore the classifier from (relative path).
model_PATH = "./model.ckpt"

# Display names for the model outputs; pipeline() pairs labels[i] with the
# i-th output score.
labels = ["Threat", "Normal", "Sarcastic"]

# Restore the classifier from the checkpoint and move it to the chosen device.
classifier = M11.load_from_checkpoint(model_PATH)
classifier = classifier.to(DEVICE)
# Presumably switches the module to inference behaviour (standard
# torch.nn.Module.eval semantics) -- M11 is defined in model.py.
classifier.eval()
|
|
|
|
|
# Gradio UI wiring: one audio input, one label output showing three classes.
# NOTE(review): gr.inputs / gr.outputs look like the legacy Gradio namespaces;
# newer releases expose gr.Audio / gr.Label instead -- confirm the pinned
# gradio version before upgrading, since pipeline() relies on the input being
# an object with a ``.name`` path.
inputs = gr.inputs.Audio(label="Input Audio", type="file")
outputs = gr.outputs.Label(num_top_classes=3)

# Text shown on the demo page.
title = "Threat Detection From Bengali Voice Calls"
description = (
    "Gradio demo for Audio Classification, simply upload your audio, "
    "or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/khalidsaifullaah' target='_blank'>Github Repo</a>"
    "</p>"
)

# Example rows offered in the UI; each row holds one value per input component.
examples = [["sample_audio.wav"]]

# Build the interface around pipeline() and start serving it.
gr.Interface(
    fn=pipeline,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()