khalidsaifullaah committed on
Commit
573ac8e
·
1 Parent(s): 009a9e0

input pipeline fixed

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -3,16 +3,27 @@ import torch
3
  from model import M11
4
  import gradio as gr
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def preprocess(signal, sr, device):
8
 
9
- # making sure channel dimension is in the 1st axis
10
- if len(signal.shape) > 1:
11
- signal = torch.permute(signal, (1, 0))
12
- # Otherwise, add a channel dimension
13
- else:
14
  signal = signal.unsqueeze(0)
15
-
16
  # resampling the audio signal with the training sample rate
17
  if sr != 8_000:
18
  resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
@@ -21,19 +32,23 @@ def preprocess(signal, sr, device):
21
  # turning the stereo signals into mono
22
  if signal.shape[0] > 1:
23
  signal = torch.mean(signal, dim=0, keepdim=True)
24
-
 
 
 
 
25
  return signal
26
 
27
 
28
- def pipeline(input):
29
 
30
- sample_rate, audio = input
 
31
 
32
- processed_audio = preprocess(torch.from_numpy(audio).float().to(DEVICE), sample_rate, DEVICE)
33
 
34
  with torch.no_grad():
35
  pred = torch.exp(classifier(processed_audio.unsqueeze(0)).squeeze()) # turning log_softmax into probabilities
36
-
37
 
38
  return {labels[i]: float(pred[i]) for i in range(3)}
39
 
@@ -47,7 +62,7 @@ classifier = M11.load_from_checkpoint(model_PATH).to(DEVICE)
47
  classifier.eval()
48
 
49
 
50
- inputs = gr.inputs.Audio(label="Input Audio", type="numpy")
51
  outputs = gr.outputs.Label(num_top_classes=3)
52
  title = "Threat Detection From Bengali Voice Calls"
53
  description = "Gradio demo for Audio Classification, simply upload your audio, or click one of the examples to load them. Read more at the links below."
@@ -55,4 +70,4 @@ article = "<p style='text-align: center'><a href='https://github.com/khalidsaifu
55
  examples = [
56
  ['sample_audio.wav']
57
  ]
58
- gr.Interface(pipeline, "audio", outputs, title=title, description=description, article=article, examples=examples).launch()
 
3
  from model import M11
4
  import gradio as gr
5
 
6
def _cut_if_necessary(signal, num_samples=400_000):
    """Truncate ``signal`` to at most ``num_samples`` entries along axis 1.

    Args:
        signal: Tensor indexed as ``signal[channel, sample]``; its length is
            read from ``signal.shape[1]``.
        num_samples: Maximum number of samples to keep. Defaults to 400000,
            the fixed length this app has always used.

    Returns:
        The first ``num_samples`` samples of every row when the signal is
        longer than that; otherwise ``signal`` itself, unchanged.
    """
    if signal.shape[1] > num_samples:
        # keep only the leading samples; anything past the limit is dropped
        signal = signal[:, :num_samples]

    return signal
11
+
12
def _right_pad_if_necessary(signal, num_samples=400_000):
    """Zero-pad ``signal`` on the right so it holds ``num_samples`` samples.

    Args:
        signal: Tensor indexed as ``signal[channel, sample]``; its length is
            read from ``signal.shape[1]``.
        num_samples: Target number of samples. Defaults to 400000, the fixed
            length this app has always used.

    Returns:
        ``signal`` extended with trailing zeros up to ``num_samples`` samples
        when it is shorter than that; otherwise ``signal`` itself, unchanged.
    """
    signal_length = signal.shape[1]
    if signal_length < num_samples:
        num_missing_samples = num_samples - signal_length
        # (left, right) padding for the last dimension: nothing is prepended,
        # and ``num_missing_samples`` zeros are appended at the end
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)

    return signal
20
 
21
  def preprocess(signal, sr, device):
22
 
23
+ # add a channel dimension for 1d samples
24
+ if len(signal.shape) == 1:
 
 
 
25
  signal = signal.unsqueeze(0)
26
+
27
  # resampling the audio signal with the training sample rate
28
  if sr != 8_000:
29
  resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
 
32
  # turning the stereo signals into mono
33
  if signal.shape[0] > 1:
34
  signal = torch.mean(signal, dim=0, keepdim=True)
35
+
36
+ signal = _cut_if_necessary(signal) # truncating longer signals
37
+ signal = _right_pad_if_necessary(signal) # extending shorter signals
38
+
39
+
40
  return signal
41
 
42
 
43
def pipeline(audio_file):
    """Classify an uploaded audio file and return per-label probabilities.

    Args:
        audio_file: Object whose ``name`` attribute is the path of the audio
            file to load.

    Returns:
        A dict mapping each of the first three entries of ``labels`` to the
        predicted probability (a Python float) for that class.
    """
    # decode the file from disk; torchaudio returns (waveform, sample_rate)
    waveform, sample_rate = torchaudio.load(audio_file.name)

    model_input = preprocess(waveform.to(DEVICE), sample_rate, DEVICE)

    with torch.no_grad():
        # add a leading batch axis for the model, then drop singleton axes
        log_probs = classifier(model_input.unsqueeze(0)).squeeze()
        # turning log_softmax into probabilities
        probabilities = torch.exp(log_probs)

    return {labels[idx]: float(probabilities[idx]) for idx in range(3)}
54
 
 
62
  classifier.eval()
63
 
64
 
65
+ inputs = gr.inputs.Audio(label="Input Audio", type="file")
66
  outputs = gr.outputs.Label(num_top_classes=3)
67
  title = "Threat Detection From Bengali Voice Calls"
68
  description = "Gradio demo for Audio Classification, simply upload your audio, or click one of the examples to load them. Read more at the links below."
 
70
  examples = [
71
  ['sample_audio.wav']
72
  ]
73
+ gr.Interface(pipeline, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()