ASesYusuf1 commited on
Commit
0bac694
·
verified ·
1 Parent(s): 5e51096

Upload ensemble.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ensemble.py +148 -0
ensemble.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
3
+
4
+ import os
5
+ import librosa
6
+ import soundfile as sf
7
+ import numpy as np
8
+ import argparse
9
+
10
+ def stft(wave, nfft, hl):
11
+ wave_left = np.asfortranarray(wave[0])
12
+ wave_right = np.asfortranarray(wave[1])
13
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
14
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
15
+ spec = np.asfortranarray([spec_left, spec_right])
16
+ return spec
17
+
18
+ def istft(spec, hl, length):
19
+ spec_left = np.asfortranarray(spec[0])
20
+ spec_right = np.asfortranarray(spec[1])
21
+ wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
22
+ wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
23
+ wave = np.asfortranarray([wave_left, wave_right])
24
+ return wave
25
+
26
+ def absmax(a, *, axis):
27
+ dims = list(a.shape)
28
+ dims.pop(axis)
29
+ indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
30
+ argmax = np.abs(a).argmax(axis=axis)
31
+ insert_pos = (len(a.shape) + axis) % len(a.shape)
32
+ indices.insert(insert_pos, argmax)
33
+ return a[tuple(indices)]
34
+
35
+ def absmin(a, *, axis):
36
+ dims = list(a.shape)
37
+ dims.pop(axis)
38
+ indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
39
+ argmax = np.abs(a).argmin(axis=axis)
40
+ insert_pos = (len(a.shape) + axis) % len(a.shape)
41
+ indices.insert(insert_pos, argmax)
42
+ return a[tuple(indices)]
43
+
44
+ def lambda_max(arr, axis=None, key=None, keepdims=False):
45
+ idxs = np.argmax(key(arr), axis)
46
+ if axis is not None:
47
+ idxs = np.expand_dims(idxs, axis)
48
+ result = np.take_along_axis(arr, idxs, axis)
49
+ if not keepdims:
50
+ result = np.squeeze(result, axis=axis)
51
+ return result
52
+ else:
53
+ return arr.flatten()[idxs]
54
+
55
+ def lambda_min(arr, axis=None, key=None, keepdims=False):
56
+ idxs = np.argmin(key(arr), axis)
57
+ if axis is not None:
58
+ idxs = np.expand_dims(idxs, axis)
59
+ result = np.take_along_axis(arr, idxs, axis)
60
+ if not keepdims:
61
+ result = np.squeeze(result, axis=axis)
62
+ return result
63
+ else:
64
+ return arr.flatten()[idxs]
65
+
66
+ def average_waveforms(pred_track, weights, algorithm):
67
+ """
68
+ :param pred_track: shape = (num, channels, length)
69
+ :param weights: shape = (num, )
70
+ :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
71
+ :return: averaged waveform in shape (channels, length)
72
+ """
73
+ pred_track = np.array(pred_track)
74
+ final_length = pred_track.shape[-1]
75
+
76
+ mod_track = []
77
+ for i in range(pred_track.shape[0]):
78
+ if algorithm == 'avg_wave':
79
+ mod_track.append(pred_track[i] * weights[i])
80
+ elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
81
+ mod_track.append(pred_track[i])
82
+ elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
83
+ spec = stft(pred_track[i], nfft=2048, hl=1024)
84
+ if algorithm in ['avg_fft']:
85
+ mod_track.append(spec * weights[i])
86
+ else:
87
+ mod_track.append(spec)
88
+ pred_track = np.array(mod_track)
89
+
90
+ if algorithm in ['avg_wave']:
91
+ pred_track = pred_track.sum(axis=0)
92
+ pred_track /= np.array(weights).sum().T
93
+ elif algorithm in ['median_wave']:
94
+ pred_track = np.median(pred_track, axis=0)
95
+ elif algorithm in ['min_wave']:
96
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
97
+ elif algorithm in ['max_wave']:
98
+ pred_track = lambda_max(pred_track, axis=0, key=np.abs)
99
+ elif algorithm in ['avg_fft']:
100
+ pred_track = pred_track.sum(axis=0)
101
+ pred_track /= np.array(weights).sum()
102
+ pred_track = istft(pred_track, 1024, final_length)
103
+ elif algorithm in ['min_fft']:
104
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
105
+ pred_track = istft(pred_track, 1024, final_length)
106
+ elif algorithm in ['max_fft']:
107
+ pred_track = absmax(pred_track, axis=0)
108
+ pred_track = istft(pred_track, 1024, final_length)
109
+ elif algorithm in ['median_fft']:
110
+ pred_track = np.median(pred_track, axis=0)
111
+ pred_track = istft(pred_track, 1024, final_length)
112
+ return pred_track
113
+
114
+ def ensemble_files(args):
115
+ parser = argparse.ArgumentParser()
116
+ parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble")
117
+ parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
118
+ parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
119
+ parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
120
+ args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
121
+
122
+ print('Ensemble type: {}'.format(args.type))
123
+ print('Number of input files: {}'.format(len(args.files)))
124
+ if args.weights is not None:
125
+ weights = args.weights
126
+ else:
127
+ weights = np.ones(len(args.files))
128
+ print('Weights: {}'.format(weights))
129
+ print('Output file: {}'.format(args.output))
130
+
131
+ data = []
132
+ for f in args.files:
133
+ if not os.path.isfile(f):
134
+ print('Error. Can\'t find file: {}. Check paths.'.format(f))
135
+ return None
136
+ print('Reading file: {}'.format(f))
137
+ wav, sr = librosa.load(f, sr=None, mono=False)
138
+ print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
139
+ data.append(wav)
140
+
141
+ data = np.array(data)
142
+ res = average_waveforms(data, weights, args.type)
143
+ print('Result shape: {}'.format(res.shape))
144
+ sf.write(args.output, res.T, sr, 'FLOAT')
145
+ return args.output
146
+
147
+ if __name__ == "__main__":
148
+ ensemble_files(None)