File size: 2,121 Bytes
b60a7b6
 
 
81d87a8
b60a7b6
 
22935a7
3dfc859
b60a7b6
5fb3738
77d5702
5fb3738
81d87a8
 
3dfc859
81d87a8
3dfc859
 
 
 
 
 
 
 
 
81d87a8
 
 
 
 
 
 
 
77d5702
b3b61c9
 
 
 
 
 
 
 
3dfc859
b60a7b6
 
 
 
 
 
 
 
 
b3b61c9
 
 
5fb3738
b3b61c9
5fb3738
b3b61c9
 
 
 
b60a7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d87a8
b60a7b6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Util functions to process any incoming audio data to be processable by the model 
"""
import os
import tempfile

# import wget
import requests
import torch
import torchaudio

# Sample rate (Hz) that every clip is resampled to before feature extraction.
DEFAULT_SAMPLE_RATE = 48_000
# Clip duration in seconds; clips are cut or zero-padded to this many seconds.
DEFAULT_WAVE_LENGTH = 3

def process_from_url(url, timeout=30):
    """Download an audio file from ``url`` and return its mel spectrogram.

    The response body is written to a uniquely named temporary ``.wav`` file,
    run through ``process_from_filename`` (default sample rate and clip
    length), and the temporary file is always deleted afterwards, even if
    processing fails.

    Args:
        url: HTTP(S) URL of the audio file.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up (``None`` waits forever).

    Returns:
        The spectrogram tensor produced by ``process_from_filename``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail fast on an error status instead of feeding an error page to the
    # audio decoder, which would fail with an unrelated, confusing message.
    response.raise_for_status()

    # A unique temp file avoids collisions between concurrent calls (a fixed
    # 'temp.wav' in the working directory could be overwritten mid-use).
    # mkstemp + closing the handle lets torchaudio reopen the path on any OS;
    # the '.wav' suffix keeps extension-based format detection working.
    fd, path = tempfile.mkstemp(suffix='.wav')
    try:
        with os.fdopen(fd, 'wb') as file:
            file.write(response.content)
        return process_from_filename(path)
    finally:
        os.remove(path)


def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
    """Load an audio file from disk and return its mel spectrogram.

    Args:
        filename: Path of the audio file, passed to ``torchaudio.load``.
        target_sample_rate: Sample rate the waveform is resampled to; also
            used for the spectrogram.
        wav_length: Clip duration (seconds) the waveform is cut/padded to.

    Returns:
        The mel spectrogram tensor from ``_wav_to_spec``.
    """
    waveform, native_rate = torchaudio.load(filename)
    clip = process_raw_wav(waveform, native_rate, target_sample_rate, wav_length)
    return _wav_to_spec(clip, target_sample_rate)

def process_raw_wav(wav, sample_rate, target_sample_rate, wav_length):
    """Normalize a raw waveform to a fixed-length mono clip.

    Args:
        wav: Waveform tensor; assumed (channels, samples) as returned by
            ``torchaudio.load`` (the helpers index ``shape[0]``/``shape[1]``).
        sample_rate: Sample rate of ``wav`` in Hz.
        target_sample_rate: Sample rate (Hz) to resample to.
        wav_length: Clip duration in seconds; may be fractional.

    Returns:
        The mixed-down, resampled waveform cut or zero-padded to exactly
        ``round(wav_length * target_sample_rate)`` samples along axis 1.
    """
    # Round to an int so a fractional duration (e.g. 2.5 s) still yields a
    # valid sample count: tensor slicing and padding both reject floats.
    num_samples = int(round(wav_length * target_sample_rate))

    # Mix to mono first so the (comparatively expensive) resampler only
    # processes one channel; both steps are linear, so the result matches
    # resample-then-mix up to floating-point rounding.
    wav = _mix_down(wav)
    wav = _resample(wav, sample_rate, target_sample_rate)
    wav = _cut(wav, num_samples)
    wav = _pad(wav, num_samples)

    return wav

def _wav_to_spec(wav, target_sample_rate):
    """Return the 128-band mel spectrogram of ``wav``.

    Uses a 2048-point FFT with a hop of 512 samples. ``target_sample_rate``
    should be the sample rate of ``wav``.
    """
    stft_params = dict(n_fft=2048, hop_length=512, n_mels=128)
    to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=target_sample_rate, **stft_params)
    spectrogram = to_mel(wav)
    return spectrogram

def _resample(wav, sample_rate, target_sample_rate):
    """Resample ``wav`` from ``sample_rate`` to ``target_sample_rate``.

    Returns ``wav`` itself (no copy) when the two rates already match.
    """
    if sample_rate == target_sample_rate:
        return wav
    return torchaudio.transforms.Resample(sample_rate, target_sample_rate)(wav)

def _mix_down(wav):
    """Average all channels (axis 0) into a single one, keeping axis 0.

    Inputs with at most one channel are returned unchanged (same object).
    """
    channel_count = wav.shape[0]
    return wav.mean(dim=0, keepdim=True) if channel_count > 1 else wav

def _cut(wav, num_samples):
    """Truncate ``wav`` along axis 1 to at most ``num_samples`` samples.

    Inputs that are not longer than ``num_samples`` are returned unchanged.
    """
    too_long = wav.shape[1] > num_samples
    return wav[:, :num_samples] if too_long else wav

def _pad(wav, num_samples):
    """Right-pad ``wav`` with zeros along its last axis up to ``num_samples``.

    Inputs that already have at least ``num_samples`` samples (axis 1) are
    returned unchanged.
    """
    shortfall = num_samples - wav.shape[1]
    if shortfall <= 0:
        return wav
    return torch.nn.functional.pad(wav, (0, shortfall))