khalidsaifullaah committed
Commit · 573ac8e
1 Parent(s): 009a9e0

input pipeline fixed

app.py CHANGED
@@ -3,16 +3,27 @@ import torch
 from model import M11
 import gradio as gr
 
+def _cut_if_necessary(signal):
+    if signal.shape[1] > 400000:
+        signal = signal[:, :400000]
+
+    return signal
+
+def _right_pad_if_necessary(signal):
+    signal_length = signal.shape[1]
+    if signal_length < 400000:
+        num_missing_samples = 400000 - signal_length
+        last_dim_padding = (0, num_missing_samples) # pad the last dim with no zeros on the left and num_missing_samples zeros on the right
+        signal = torch.nn.functional.pad(signal, last_dim_padding)
+
+    return signal
 
 def preprocess(signal, sr, device):
 
-    # …
-    if len(signal.shape) …
-        signal = torch.permute(signal, (1, 0))
-    # Otherwise, add a channel dimension
-    else:
+    # add a channel dimension for 1d samples
+    if len(signal.shape) == 1:
         signal = signal.unsqueeze(0)
-
+
     # resampling the audio signal with the training sample rate
     if sr != 8_000:
         resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
@@ -21,19 +32,23 @@ def preprocess(signal, sr, device):
     # turning the stereo signals into mono
     if signal.shape[0] > 1:
         signal = torch.mean(signal, dim=0, keepdim=True)
-
+
+    signal = _cut_if_necessary(signal) # truncating longer signals
+    signal = _right_pad_if_necessary(signal) # extending shorter signals
+
+
     return signal
 
 
-def pipeline(…
+def pipeline(audio_file):
 
-
+    audio_PATH = audio_file.name
+    audio, sample_rate = torchaudio.load(audio_PATH)
 
-    processed_audio = preprocess(…
+    processed_audio = preprocess(audio.to(DEVICE), sample_rate, DEVICE)
 
     with torch.no_grad():
         pred = torch.exp(classifier(processed_audio.unsqueeze(0)).squeeze()) # turning log_softmax into probabilities
-
 
     return {labels[i]: float(pred[i]) for i in range(3)}
 
@@ -47,7 +62,7 @@ classifier = M11.load_from_checkpoint(model_PATH).to(DEVICE)
 classifier.eval()
 
 
-inputs = gr.inputs.Audio(label="Input Audio", type="…
+inputs = gr.inputs.Audio(label="Input Audio", type="file")
 outputs = gr.outputs.Label(num_top_classes=3)
 title = "Threat Detection From Bengali Voice Calls"
 description = "Gradio demo for Audio Classification, simply upload your audio, or click one of the examples to load them. Read more at the links below."
@@ -55,4 +70,4 @@ article = "<p style='text-align: center'><a href='https://github.com/khalidsaifu
 examples = [
     ['sample_audio.wav']
 ]
-gr.Interface(pipeline, …
+gr.Interface(pipeline, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()
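The two new helpers pin every input to exactly 400,000 samples (50 seconds at the 8 kHz training rate): _cut_if_necessary truncates longer clips and _right_pad_if_necessary zero-pads shorter ones on the right. A minimal sketch of that invariant, mirroring the helper logic from the diff (TARGET_LEN and the test clips are illustrative, not from the repo):

import torch
import torch.nn.functional as F

TARGET_LEN = 400_000  # illustrative constant; the diff hard-codes 400000

def _cut_if_necessary(signal):
    # keep only the first TARGET_LEN samples of longer clips
    if signal.shape[1] > TARGET_LEN:
        signal = signal[:, :TARGET_LEN]
    return signal

def _right_pad_if_necessary(signal):
    # zero-pad shorter clips on the right up to TARGET_LEN samples
    if signal.shape[1] < TARGET_LEN:
        signal = F.pad(signal, (0, TARGET_LEN - signal.shape[1]))
    return signal

short_clip = torch.randn(1, 8_000)    # 1 s at 8 kHz -> gets padded
long_clip = torch.randn(1, 800_000)   # 100 s at 8 kHz -> gets truncated
for clip in (short_clip, long_clip):
    out = _right_pad_if_necessary(_cut_if_necessary(clip))
    assert out.shape == (1, TARGET_LEN)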
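Upstream of the helpers, the reworked preprocess normalizes any upload to the training format: add a channel dimension to 1-D input, resample to 8 kHz, downmix stereo to mono, then cut or pad to the fixed length. A self-contained sketch of that contract under the same assumptions (the dummy clip and the inlined cut/pad steps are illustrative; only torch and torchaudio are assumed):

import torch
import torchaudio

TARGET_LEN = 400_000  # fixed length enforced by the helpers in the diff

def preprocess(signal, sr, device="cpu"):
    # mirrors the diff: channel dim -> 8 kHz -> mono -> fixed length
    if len(signal.shape) == 1:
        signal = signal.unsqueeze(0)                      # add channel dimension
    if sr != 8_000:
        resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
        signal = resampler(signal.to(device))
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)  # stereo -> mono
    signal = signal[:, :TARGET_LEN]                       # cut if necessary
    if signal.shape[1] < TARGET_LEN:                      # right-pad if necessary
        signal = torch.nn.functional.pad(signal, (0, TARGET_LEN - signal.shape[1]))
    return signal

clip = torch.randn(2, 441_000)         # made-up 10 s stereo clip at 44.1 kHz
print(preprocess(clip, 44_100).shape)  # torch.Size([1, 400000])

The switch to type="file" matters here: Gradio passes pipeline a temporary file object whose .name is a path on disk, which is exactly what torchaudio.load expects.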