# NOTE: the next three lines are hosting-page residue, not code; kept as comments.
# ASesYusuf1's picture
# Upload folder using huggingface_hub
# 3978e51
import os
from abc import abstractmethod
from typing import Callable
import numpy as np
import torch
from librosa import hz_to_midi, midi_to_hz
from torchaudio import functional as taF
# from spafe.fbanks import bark_fbanks
# from spafe.utils.converters import erb2hz, hz2bark, hz2erb
def band_widths_from_specs(band_specs):
    """Return the width (``end - start``) of every ``(start, end)`` band, in order."""
    widths = []
    for start, end in band_specs:
        widths.append(end - start)
    return widths
def check_nonzero_bandwidth(band_specs):
    """Validate that every ``(start, end)`` band has strictly positive width.

    Raises:
        ValueError: if any band has ``end - start <= 0``.
    """
    if any(end - start <= 0 for start, end in band_specs):
        raise ValueError("Bands cannot be zero-width")
def check_no_overlap(band_specs):
    """Validate that no band starts before an earlier band has ended.

    Bands are ``(start, end)`` index pairs interpreted as half-open ranges,
    which is how the specifications in this module build them (each band
    starts exactly where the previous one ends). A band that starts at the
    previous band's end is therefore contiguous, not overlapping.

    Args:
        band_specs: iterable of ``(start, end)`` pairs in ascending order.

    Raises:
        ValueError: if a band starts below the furthest end seen so far, or
            if a band starts at a negative index.
    """
    # Starting from 0 keeps rejecting negative start indices.
    fend_prev = 0
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        # Track the furthest end so far. Without this update the check can
        # never detect an overlap between two bands.
        fend_prev = max(fend_prev, fend_curr)
def check_no_gap(band_specs):
    """Validate that the bands start at index 0 and leave no gap between them.

    A band may start at the previous band's end or one index above it; any
    larger jump is reported as a gap.

    Raises:
        AssertionError: if the first band does not start at 0.
        ValueError: if a band starts more than one index past the previous end.
    """
    first_start, _first_end = band_specs[0]
    assert first_start == 0

    previous_end = -1
    for start, end in band_specs:
        jump = start - previous_end
        if jump > 1:
            raise ValueError("Bands cannot leave gap")
        previous_end = end
class BandsplitSpecification:
    """Base class that converts between hertz and FFT bin indices and builds band lists.

    Bands are ``(start_index, end_index)`` pairs of bin indices; consecutive
    bands produced here share their boundary (one band's end is the next
    band's start). Subclasses implement :meth:`get_band_specs`.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        """Store the FFT size and sampling rate and precompute common split points.

        Args:
            nfft: FFT size.
            fs: sampling rate in hertz.
        """
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1

        # Bin indices nearest to the named frequencies.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        # NOTE(review): when fs is below 40 kHz, split20k exceeds max_index
        # and these "above" bands come out with non-positive width.
        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Return the frequency in hertz of bin ``index``."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Return the bin index for ``hz``.

        Args:
            hz: frequency in hertz.
            round: when true, round to the nearest bin and return an ``int``;
                otherwise return the fractional index unchanged.
        """
        index = hz * self.nfft / self.fs
        if round:
            index = int(np.round(index))
        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile ``[start_index, end_index)`` with bands of roughly ``bandwidth_hz``.

        The last band is truncated at ``end_index``. A bandwidth narrower
        than one bin is widened to one bin, since a band cannot be narrower
        than a single bin.

        Args:
            start_index: first bin of the first band.
            end_index: end bin of the last band.
            bandwidth_hz: requested band width in hertz; must be positive.

        Returns:
            List of ``(lower, upper)`` bin-index pairs; empty when
            ``start_index`` is not below ``end_index``.

        Raises:
            ValueError: if the range is non-empty and ``bandwidth_hz <= 0``.
        """
        if not start_index < end_index:
            return []
        if bandwidth_hz <= 0:
            raise ValueError("bandwidth_hz must be positive")

        # The step is loop-invariant. Clamping it to one bin guarantees the
        # loop advances; a zero-bin step would otherwise never terminate.
        step = max(1, self.hertz_to_index(bandwidth_hz))

        band_specs = []
        lower = start_index
        while lower < end_index:
            upper = int(np.floor(lower + step))
            upper = min(upper, end_index)
            band_specs.append((lower, upper))
            lower = upper
        return band_specs

    @abstractmethod
    def get_band_specs(self):
        """Return the list of ``(start_index, end_index)`` bands; subclasses override."""
        raise NotImplementedError
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts, selectable by version ``"1"`` to ``"7"``.

    Each layout tiles the spectrum with runs of uniform-width bands; later
    versions use narrower bands at low frequencies.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        """Store the layout version on top of the base FFT parameters.

        Args:
            nfft: FFT size.
            fs: sampling rate in hertz.
            version: suffix of the ``version<N>`` layout to use.
        """
        super().__init__(nfft=nfft, fs=fs)
        self.version = version

    def get_band_specs(self):
        """Return the band list of the configured version.

        Raises:
            AttributeError: if no ``version<N>`` layout exists for
                ``self.version``.
        """
        layout = getattr(self, f"version{self.version}")
        # ``version1`` is a property and already evaluates to the band list,
        # while the other layouts are methods. Calling the result
        # unconditionally raised TypeError for version "1".
        return layout() if callable(layout) else layout

    def _bands(self, *segments):
        """Concatenate uniform-width runs given as ``(start, end, bandwidth_hz)``."""
        band_specs = []
        for start_index, end_index, bandwidth_hz in segments:
            band_specs += self.get_band_specs_with_bandwidth(
                start_index=start_index,
                end_index=end_index,
                bandwidth_hz=bandwidth_hz,
            )
        return band_specs

    @property
    def version1(self):
        """1 kHz bands over the whole spectrum."""
        return self._bands((0, self.max_index, 1000))

    def version2(self):
        """1 kHz bands below 16 kHz, 2 kHz bands up to 20 kHz, one band above."""
        return (
            self._bands(
                (0, self.split16k, 1000),
                (self.split16k, self.split20k, 2000),
            )
            + self.above20k
        )

    def version3(self):
        """1 kHz bands below 8 kHz, 2 kHz bands up to 16 kHz, then ``above16k``."""
        return (
            self._bands(
                (0, self.split8k, 1000),
                (self.split8k, self.split16k, 2000),
            )
            + self.above16k
        )

    def version4(self):
        """100 Hz below 1 kHz, 1 kHz up to 8 kHz, 2 kHz up to 16 kHz, then ``above16k``."""
        return (
            self._bands(
                (0, self.split1k, 100),
                (self.split1k, self.split8k, 1000),
                (self.split8k, self.split16k, 2000),
            )
            + self.above16k
        )

    def version5(self):
        """100 Hz below 1 kHz, 1 kHz up to 16 kHz, 2 kHz up to 20 kHz, one band above."""
        return (
            self._bands(
                (0, self.split1k, 100),
                (self.split1k, self.split16k, 1000),
                (self.split16k, self.split20k, 2000),
            )
            + self.above20k
        )

    def version6(self):
        """100 Hz / 500 Hz / 1 kHz / 2 kHz runs split at 1, 4, 8, 16 kHz, then ``above16k``."""
        return (
            self._bands(
                (0, self.split1k, 100),
                (self.split1k, self.split4k, 500),
                (self.split4k, self.split8k, 1000),
                (self.split8k, self.split16k, 2000),
            )
            + self.above16k
        )

    def version7(self):
        """100 / 250 / 500 Hz / 1 / 2 kHz runs split at 1, 4, 8, 16, 20 kHz, one band above."""
        return (
            self._bands(
                (0, self.split1k, 100),
                (self.split1k, self.split4k, 250),
                (self.split4k, self.split8k, 500),
                (self.split8k, self.split16k, 1000),
                (self.split16k, self.split20k, 2000),
            )
            + self.above20k
        )
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band layout for the "other" stem: the vocal layout pinned to version "7"."""

    def __init__(self, nfft: int, fs: int) -> None:
        """Initialise the parent with ``nfft``, ``fs`` and a fixed version of "7"."""
        super().__init__(nfft, fs, version="7")
class BassBandsplitSpecification(BandsplitSpecification):
    """Band layout with narrow bands at low frequencies and one band above 16 kHz."""

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        """Initialise the base specification.

        ``version`` is accepted but not used by this class.
        """
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return the bass band list.

        Uniform runs of 50 Hz, 100 Hz, 500 Hz, 1 kHz and 2 kHz bands, split
        at 500 Hz, 1 kHz, 4 kHz, 8 kHz and 16 kHz, followed by a single band
        from 16 kHz to ``max_index``.
        """
        # (start_index, end_index, bandwidth_hz) for each uniform run.
        runs = (
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        )

        band_specs = []
        for run_start, run_end, width_hz in runs:
            band_specs.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=run_start, end_index=run_end, bandwidth_hz=width_hz
                )
            )
        band_specs.append((self.split16k, self.max_index))
        return band_specs
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band layout with 50 Hz bands below 1 kHz and one band above 16 kHz."""

    def __init__(self, nfft: int, fs: int) -> None:
        """Initialise the base specification with ``nfft`` and ``fs``."""
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return the drum band list.

        Uniform runs of 50 Hz, 100 Hz, 250 Hz, 500 Hz and 1 kHz bands, split
        at 1 kHz, 2 kHz, 4 kHz, 8 kHz and 16 kHz, followed by a single band
        from 16 kHz to ``max_index``.
        """
        # (start_index, end_index, bandwidth_hz) for each uniform run.
        runs = (
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        )

        band_specs = []
        for run_start, run_end, width_hz in runs:
            band_specs.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=run_start, end_index=run_end, bandwidth_hz=width_hz
                )
            )
        band_specs.append((self.split16k, self.max_index))
        return band_specs
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band layout derived from a filterbank instead of fixed frequency splits.

    Each filterbank row becomes one band spanning its first to last non-zero
    bin. Per-bin weights are the filterbank values divided by the sum over
    all bands at that bin.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        """Build the filterbank and derive band ranges and weights from it.

        Args:
            nfft: FFT size.
            fs: sampling rate in hertz.
            fbank_fn: called as ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)``;
                its result is indexed as ``[band, frequency_bin]``.
            n_bands: number of filterbank rows to scan.
            f_min: lower frequency bound passed to ``fbank_fn``.
            f_max: upper frequency bound; ``fs / 2`` when ``None``.
        """
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands

        upper_hz = fs / 2 if f_max is None else f_max
        self.filterbank = fbank_fn(n_bands, fs, f_min, upper_hz, self.max_index)

        # Total weight per frequency bin across bands: (1, n_freqs).
        per_bin_total = self.filterbank.sum(dim=0, keepdim=True)
        # Filterbank rescaled by that per-bin total: (n_bands, n_freqs).
        normalized_fb = self.filterbank / per_bin_total

        collected_specs = []
        collected_weights = []
        for band in range(self.n_bands):
            # flatten() always yields a list, even for a single active bin.
            nonzero_bins = torch.nonzero(self.filterbank[band, :]).flatten().tolist()
            if not nonzero_bins:
                # Rows with no active bin contribute no band.
                continue

            first_bin = nonzero_bins[0]
            end_bin = nonzero_bins[-1] + 1  # exclusive end
            collected_specs.append((first_bin, end_bin))
            collected_weights.append(normalized_fb[band, first_bin:end_bin])

        self.freq_weights = collected_weights
        self.band_specs = collected_specs

    def get_band_specs(self):
        """Return the ``(start_index, end_index)`` pairs derived from the filterbank."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the per-band slices of the normalized filterbank."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle the band specs, weights and filterbank into ``dir_path``.

        The directory is created if missing; the file is always named
        ``mel_bandsplit_spec.pkl``.
        """
        import pickle

        os.makedirs(dir_path, exist_ok=True)
        payload = {
            "band_specs": self.band_specs,
            "freq_weights": self.freq_weights,
            "filterbank": self.filterbank,
        }
        target_path = os.path.join(dir_path, "mel_bandsplit_spec.pkl")
        with open(target_path, "wb") as f:
            pickle.dump(payload, f)
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Return a mel filterbank laid out as ``[band, frequency_bin]``.

    Wraps ``torchaudio.functional.melscale_fbanks`` and transposes its result.
    """
    per_freq_fb = taF.melscale_fbanks(
        n_freqs=n_freqs,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_bands,
        sample_rate=fs,
    )
    fb = per_freq_fb.T

    # Give the first band full weight on bin 0.
    # NOTE(review): presumably so the DC bin is covered by a band; confirm.
    fb[0, 0] = 1.0
    return fb
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band layout built from :func:`mel_filterbank`."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        """Forward all arguments to the parent, with ``mel_filterbank`` as the filterbank."""
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Return a rectangular (0/1) filterbank with centres evenly spaced in MIDI pitch.

    Centres run from one bin width (``fs / nfft``) up to ``f_max``. Each band
    covers its centre divided and multiplied by ``2 ** (n_octaves / n_bands)``.
    The first band is extended down to bin 0 and the last band up to the
    final bin.

    Args:
        n_bands: number of bands (rows).
        fs: sampling rate in hertz.
        f_min: accepted for signature compatibility; its value is overridden
            by ``fs / nfft`` and never used.
        f_max: top frequency; ``fs / 2`` when falsy.
        n_freqs: number of frequency bins (columns).
        scale: unused.

    Returns:
        Tensor of shape ``(n_bands, n_freqs)`` containing zeros and ones.
    """
    nfft = 2 * (n_freqs - 1)
    bin_hz = fs / nfft

    f_max = f_max or fs / 2
    # The lowest centre is pinned to one bin width regardless of the f_min
    # argument. NOTE(review): presumably to avoid log2(0); confirm before
    # changing.
    f_min = bin_hz

    octaves_per_band = np.log2(f_max / f_min) / n_bands
    bandwidth_mult = np.power(2.0, octaves_per_band)

    midi_lo = max(0, hz_to_midi(f_min))
    midi_hi = hz_to_midi(f_max)
    centres_hz = midi_to_hz(np.linspace(midi_lo, midi_hi, n_bands))

    first_bins = np.floor(centres_hz / bandwidth_mult / bin_hz).astype(int)
    last_bins = np.ceil(centres_hz * bandwidth_mult / bin_hz).astype(int)

    fb = np.zeros((n_bands, n_freqs))
    for band, (first, last) in enumerate(zip(first_bins, last_bins)):
        fb[band, first : last + 1] = 1.0

    # Extend the outermost bands so every bin belongs to at least one band.
    fb[0, : first_bins[0]] = 1.0
    fb[-1, last_bins[-1] + 1 :] = 1.0
    return torch.as_tensor(fb)
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band layout built from :func:`musical_filterbank`."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        """Forward all arguments to the parent, with ``musical_filterbank`` as the filterbank."""
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
# def bark_filterbank(
# n_bands, fs, f_min, f_max, n_freqs
# ):
# nfft = 2 * (n_freqs -1)
# fb, _ = bark_fbanks.bark_filter_banks(
# nfilts=n_bands,
# nfft=nfft,
# fs=fs,
# low_freq=f_min,
# high_freq=f_max,
# scale="constant"
# )
# return torch.as_tensor(fb)
# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
# def __init__(
# self,
# nfft: int,
# fs: int,
# n_bands: int,
# f_min: float = 0.0,
# f_max: float = None
# ) -> None:
# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
# def triangular_bark_filterbank(
# n_bands, fs, f_min, f_max, n_freqs
# ):
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
# # calculate mel freq bins
# m_min = hz2bark(f_min)
# m_max = hz2bark(f_max)
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
# f_pts = 600 * torch.sinh(m_pts / 6)
# # create filterbank
# fb = _create_triangular_filterbank(all_freqs, f_pts)
# fb = fb.T
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
# fb[first_active_band, :first_active_bin] = 1.0
# return fb
# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
# def __init__(
# self,
# nfft: int,
# fs: int,
# n_bands: int,
# f_min: float = 0.0,
# f_max: float = None
# ) -> None:
# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
# def minibark_filterbank(
# n_bands, fs, f_min, f_max, n_freqs
# ):
# fb = bark_filterbank(
# n_bands,
# fs,
# f_min,
# f_max,
# n_freqs
# )
# fb[fb < np.sqrt(0.5)] = 0.0
# return fb
# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
# def __init__(
# self,
# nfft: int,
# fs: int,
# n_bands: int,
# f_min: float = 0.0,
# f_max: float = None
# ) -> None:
# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
# def erb_filterbank(
# n_bands: int,
# fs: int,
# f_min: float,
# f_max: float,
# n_freqs: int,
# ) -> Tensor:
# # freq bins
# A = (1000 * np.log(10)) / (24.7 * 4.37)
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
# # calculate mel freq bins
# m_min = hz2erb(f_min)
# m_max = hz2erb(f_max)
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
# # create filterbank
# fb = _create_triangular_filterbank(all_freqs, f_pts)
# fb = fb.T
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
# fb[first_active_band, :first_active_bin] = 1.0
# return fb
# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
# def __init__(
# self,
# nfft: int,
# fs: int,
# n_bands: int,
# f_min: float = 0.0,
# f_max: float = None
# ) -> None:
# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
if __name__ == "__main__":
    # Dump the default vocal band layout (nfft=2048, fs=44100) to a CSV file.
    import pandas as pd

    band_defs = []
    for bands in [VocalBandsplitSpecification]:
        band_name = bands.__name__.replace("BandsplitSpecification", "")
        mbs = bands(nfft=2048, fs=44100).get_band_specs()
        band_defs.extend(
            {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
            for i, (f_min, f_max) in enumerate(mbs)
        )

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)