ajayarora1235 committed on
Commit
903c36e
·
1 Parent(s): 4da933d

add missing file

Browse files
Files changed (1) hide show
  1. trainset_preprocess_pipeline_print.py +128 -0
trainset_preprocess_pipeline_print.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, multiprocessing
2
+ from scipy import signal
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+
7
+ import numpy as np, os, traceback
8
+ from slicer2 import Slicer
9
+ import librosa, traceback
10
+ from scipy.io import wavfile
11
+ import multiprocessing
12
+ from my_utils import load_audio
13
+
14
# Serializes console output so lines printed by concurrent worker processes
# do not interleave.
mutex = multiprocessing.Lock()


def println(strr):
    """Print ``strr`` to stdout while holding the module-level ``mutex``.

    The lock is taken through a ``with`` block so that it is released even
    when ``print`` raises (for example, if ``str(strr)`` fails); a manual
    acquire/release pair would leave the lock held forever in that case and
    dead-lock every later call.

    Args:
        strr: Any printable object.
    """
    with mutex:
        print(strr)
21
+
22
+
23
class PreProcess:
    """Cut training audio into short, normalized clips at two sample rates.

    Each input file is high-pass filtered, split by ``Slicer``, cut into
    overlapping windows, normalized and written twice: at ``sr`` under
    ``<exp_dir>/0_gt_wavs`` and resampled to 16 kHz under
    ``<exp_dir>/1_16k_wavs``.
    """

    def __init__(self, sr, exp_dir, noparallel=True):
        """Configure the slicer and filter, and create the output directories.

        Args:
            sr: Sample rate used for loading audio and for the ``0_gt_wavs``
                output.
            exp_dir: Experiment directory; created if missing together with
                its ``0_gt_wavs`` and ``1_16k_wavs`` sub-directories.
            noparallel: When true, files are processed sequentially in the
                calling process; when false, ``n_p`` worker processes are
                spawned by ``pipeline_mp_inp_dir``.
        """
        self.slicer = Slicer(
            sr=sr,
            threshold=-42,
            min_length=1500,
            min_interval=400,
            hop_size=15,
            max_sil_kept=500,
        )
        self.sr = sr
        # 5th-order Butterworth high-pass; with ``fs`` given, the cutoff
        # ``Wn=48`` is expressed in the same unit as ``sr``.
        self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
        self.per = 3.7  # window length, in seconds
        self.overlap = 0.3  # overlap between consecutive windows, in seconds
        # A remainder no longer than this is emitted whole as the last clip.
        self.tail = self.per + self.overlap
        self.max = 0.9  # target peak used by the normalization
        self.alpha = 0.75  # weight of the peak-normalized signal in the mix
        self.exp_dir = exp_dir
        self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
        self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
        # Honor the caller's choice (previously this was hard-coded to True,
        # silently ignoring the argument).
        self.noparallel = noparallel
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def norm_write(self, tmp_audio, idx0, idx1):
        """Normalize one clip and write it at ``self.sr`` and at 16 kHz.

        The output is a blend of the peak-normalized clip (weight
        ``self.alpha``, peak ``self.max``) and the untouched clip (weight
        ``1 - self.alpha``). Both files are named ``<idx0>_<idx1>.wav``.

        Empty, all-zero or NaN-peaked clips are skipped with a log line:
        normalizing them would divide by zero (writing NaNs) or raise.

        Args:
            tmp_audio: 1-D NumPy array of samples at ``self.sr``.
            idx0: Index of the source file.
            idx1: Index of the clip within the source file.
        """
        peak = np.abs(tmp_audio).max() if len(tmp_audio) else 0
        if not peak > 0:  # also true for NaN
            println("%s_%s->skipped (empty or silent clip)" % (idx0, idx1))
            return
        tmp_audio = (tmp_audio / peak * (self.max * self.alpha)) + (
            1 - self.alpha
        ) * tmp_audio
        wavfile.write(
            "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
            self.sr,
            tmp_audio.astype(np.float32),
        )
        tmp_audio = librosa.resample(
            tmp_audio, orig_sr=self.sr, target_sr=16000
        )  # , res_type="soxr_vhq"
        wavfile.write(
            "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
            16000,
            tmp_audio.astype(np.float32),
        )

    def _iter_chunks(self, slices):
        """Yield ``(idx1, chunk)`` for every window of every slice.

        Each slice is cut into windows of ``self.per`` seconds that start
        every ``self.per - self.overlap`` seconds; once the remainder is no
        longer than ``self.tail`` seconds it is emitted whole as the last
        chunk of that slice.

        ``idx1`` is incremented after *every* emitted chunk, so indices are
        unique and consecutive across slices. (Incrementing before the tail
        and not after it made the first window of the next slice reuse the
        tail's index and overwrite its file.)
        """
        hop = self.per - self.overlap
        idx1 = 0
        for piece in slices:
            i = 0
            while True:
                start = int(self.sr * hop * i)
                i += 1
                if len(piece[start:]) > self.tail * self.sr:
                    yield idx1, piece[start : start + int(self.per * self.sr)]
                    idx1 += 1
                else:
                    yield idx1, piece[start:]
                    idx1 += 1
                    break

    def pipeline(self, path, idx0):
        """Load, filter, slice and write all clips for one input file.

        Failures are logged through ``println`` (with the traceback) and do
        not propagate, so one bad file does not stop the batch.

        Args:
            path: Path of the audio file to process.
            idx0: Index of the file, used as the first part of clip names.
        """
        try:
            # presumably returns a 1-D sample array at self.sr - confirm
            # against my_utils.load_audio
            audio = load_audio(path, self.sr)
            # Causal filtering on purpose: the zero-phase alternative
            # (signal.filtfilt) causes pre-ringing noise.
            audio = signal.lfilter(self.bh, self.ah, audio)

            for idx1, chunk in self._iter_chunks(self.slicer.slice(audio)):
                self.norm_write(chunk, idx0, idx1)
            println("%s->Suc." % path)
        except Exception:
            # Not a bare ``except`` so Ctrl-C / SystemExit still stop the run.
            println("%s->%s" % (path, traceback.format_exc()))

    def pipeline_mp(self, infos):
        """Run ``pipeline`` for each ``(path, idx0)`` pair in ``infos``."""
        for path, idx0 in infos:
            self.pipeline(path, idx0)

    def pipeline_mp_inp_dir(self, inp_root, n_p):
        """Process every entry of ``inp_root``, split into ``n_p`` batches.

        Entries are sorted by name and numbered in that order; batch ``i``
        receives entries ``i, i + n_p, i + 2 * n_p, ...``. Batches run one
        after another when ``self.noparallel`` is true, otherwise each batch
        runs in its own ``multiprocessing.Process`` and all are joined.

        Args:
            inp_root: Directory containing the input audio files.
            n_p: Number of batches / worker processes; values below 1 are
                treated as 1 so that the input is never silently skipped.
        """
        try:
            n_p = max(1, n_p)
            infos = [
                ("%s/%s" % (inp_root, name), idx)
                for idx, name in enumerate(sorted(os.listdir(inp_root)))
            ]
            if self.noparallel:
                for i in range(n_p):
                    self.pipeline_mp(infos[i::n_p])
            else:
                ps = []
                for i in range(n_p):
                    p = multiprocessing.Process(
                        target=self.pipeline_mp, args=(infos[i::n_p],)
                    )
                    ps.append(p)
                    p.start()
                for p in ps:
                    p.join()
        except Exception:
            println("Fail. %s" % traceback.format_exc())
117
+
118
+
119
def preprocess_trainset(inp_root, sr, n_p, exp_dir, noparallel=True):
    """Preprocess every audio file in ``inp_root`` into ``exp_dir``.

    Builds a ``PreProcess`` for the given sample rate and experiment
    directory, then runs it over the input directory, logging the start,
    the process arguments and the end through ``println``.

    Args:
        inp_root: Directory containing the input audio files.
        sr: Target sample rate handed to ``PreProcess``.
        n_p: Number of batches / worker processes.
        exp_dir: Experiment output directory.
        noparallel: Forwarded to ``PreProcess``.
    """
    preprocessor = PreProcess(sr, exp_dir, noparallel=noparallel)
    for message in ("start preprocess", sys.argv):
        println(message)
    preprocessor.pipeline_mp_inp_dir(inp_root, n_p)
    println("end preprocess")
125
+
126
+
127
if __name__ == "__main__":
    # Command line: <inp_root> <sr> <n_p> <exp_dir> [noparallel]
    # The names used in the call below were never defined at module level,
    # so running this file as a script raised NameError; read them from the
    # command line instead.
    # NOTE(review): argument order follows preprocess_trainset's parameter
    # order - confirm against whatever launches this script.
    if len(sys.argv) < 5:
        sys.exit(
            "usage: %s <inp_root> <sr> <n_p> <exp_dir> [True|False]" % sys.argv[0]
        )
    inp_root = sys.argv[1]
    sr = int(sys.argv[2])
    n_p = int(sys.argv[3])
    exp_dir = sys.argv[4]
    # Optional fifth argument; defaults to True like preprocess_trainset.
    noparallel = sys.argv[5] == "True" if len(sys.argv) > 5 else True
    preprocess_trainset(inp_root, sr, n_p, exp_dir, noparallel)