Spaces:
Runtime error
Runtime error
| ''' | |
| @File : subband_util.py | |
| @Contact : [email protected] | |
| @License : (C)Copyright 2020-2021 | |
| @Modify Time @Author @Version @Desciption | |
| ------------ ------- -------- ----------- | |
| 2020/4/3 4:54 PM Haohe Liu 1.0 None | |
| ''' | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import numpy as np | |
| import os.path as op | |
| from scipy.io import loadmat | |
def load_mat2numpy(fname=""):
    '''
    Load a MATLAB ``.mat`` file with ``scipy.io.loadmat``.

    Args:
        fname: path to the ``.mat`` file; an empty string means "nothing to load".
    Returns:
        the dict produced by ``loadmat`` (variable name -> value), or ``None``
        when ``fname`` is empty.
    '''
    if len(fname) != 0:
        return loadmat(fname)
    return None
class PQMF(nn.Module):
    '''
    Fixed (non-trainable) subband analysis/synthesis filter bank.

    Filter coefficients are read from two MATLAB files in ``project_root``:
    ``f_{N}_{M}.mat`` (variable ``'f'``, analysis filters) and
    ``h_{N}_{M}.mat`` (variable ``'h'``, synthesis filters); each must hold
    ``N * M`` values. All parameters have ``requires_grad=False``.

    Args:
        N: number of subbands; also the analysis stride (decimation factor).
        M: analysis filter length; synthesis uses kernels of length ``M // N``.
        project_root: directory containing the ``.mat`` coefficient files.
    '''

    # (N, M) pairs this class was written for; other pairs only print a warning.
    _SUPPORTED = ((8, 64), (4, 64), (2, 64))

    def __init__(self, N, M, project_root):
        super().__init__()
        self.N = N  # number of subbands
        self.M = M  # analysis filter length
        # Explicit check instead of try/assert/bare-except: asserts are
        # stripped under `python -O`, and a bare except can mask other errors.
        if (N, M) not in self._SUPPORTED:
            print("Warning:", N, "subbandand ", M, " filter is not supported")
        # Zeros appended to the signal by analysis() and trimmed by synthesis().
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"

        # Analysis: 1 input channel -> N subbands, decimated by stride N.
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        ana = self._load_coefficients(project_root, "f_", "f") / N
        # Reverse every filter along time; presumably so that Conv1d (a
        # cross-correlation) applies a true convolution.
        ana = np.flipud(ana.T).T
        self._set_weight(self.ana_conv_filter, np.reshape(ana, (N, 1, M)))
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)

        # Synthesis: N subbands -> N polyphase outputs, interleaved afterwards.
        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        syn = self._load_coefficients(project_root, "h_", "h")
        syn = np.transpose(np.reshape(syn, (N, M // N, N)), (1, 0, 2)) * N
        syn = np.transpose(syn[::-1, :, :], (2, 1, 0))
        self._set_weight(self.syn_conv_filter, syn)

        for param in self.parameters():
            param.requires_grad = False

    def _load_coefficients(self, project_root, prefix, key):
        '''Read variable ``key`` from ``<project_root>/<prefix><self.name>`` as float32.'''
        mat = loadmat(op.join(project_root, prefix + self.name))
        return mat[key].astype(np.float32)

    @staticmethod
    def _set_weight(conv, weight):
        '''Copy a numpy array into ``conv.weight``; load_state_dict checks the shape.'''
        # ascontiguousarray: flipped/transposed views have strides torch rejects.
        conv.load_state_dict(
            {"weight": torch.from_numpy(np.ascontiguousarray(weight))}
        )

    def __analysis_channel(self, inputs):
        '''Left-pad by M - N and apply the analysis filters to [batch, 1, time].'''
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __synthesis_channel(self, inputs):
        '''Merge [batch, N, time] subbands into [batch, 1, time * N].'''
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        # (batch, time, N) flattened row-major interleaves the N outputs in time.
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        '''
        Split every channel of ``inputs`` into ``self.N`` subbands.

        :param inputs: tensor of shape [batchsize, channel, raw_wav]
        :return: tensor of shape
            [batchsize, channel * self.N, (raw_wav + self.pad_samples) // self.N];
            the subbands of input channel ``c`` are output channels
            ``c * N`` .. ``c * N + N - 1``.
        '''
        inputs = F.pad(inputs, (0, self.pad_samples))
        batch, channels, length = inputs.shape
        # All channels share the same 1-input-channel filters, so fold the
        # channels into the batch dim and run a single conv instead of a
        # loop of torch.cat calls.
        sub = self.__analysis_channel(inputs.reshape(batch * channels, 1, length))
        return sub.reshape(batch, channels * self.N, sub.shape[-1])

    def synthesis(self, data):
        '''
        Reconstruct waveforms from subband signals produced by ``analysis``.

        :param data: tensor of shape [batchsize, self.N * K, raw_wav_sub]; each
            consecutive group of ``self.N`` channels becomes one output channel.
        :return: tensor of shape
            [batchsize, K, raw_wav_sub * self.N - self.pad_samples].
        :raises ValueError: if the channel count is not a multiple of ``self.N``.
        '''
        batch, channels, length = data.shape
        if channels % self.N != 0:
            raise ValueError(
                f"synthesis expects a multiple of N={self.N} channels, got {channels}"
            )
        groups = channels // self.N
        # Fold the K channel groups into the batch dim: one conv call for all.
        merged = self.__synthesis_channel(
            data.reshape(batch * groups, self.N, length)
        )
        ret = merged.reshape(batch, groups, merged.shape[-1])
        # Drop the tail corresponding to the zeros analysis() appended;
        # max(..., 0) keeps this correct even if pad_samples were 0.
        return ret[..., : max(ret.shape[-1] - self.pad_samples, 0)]

    def forward(self, inputs):
        '''Apply the analysis filters to [batch, 1, time] (no extra right padding).'''
        return self.__analysis_channel(inputs)
if __name__ == "__main__":
    # Manual round-trip check: split random signals into subbands, resynthesize
    # them, then plot and print the difference between input and reconstruction.
    # Needs f_4_64.mat and h_4_64.mat in the coefficient directory, which can be
    # given as the first command-line argument (default kept from the original).
    import sys

    import matplotlib.pyplot as plt

    project_root = (
        sys.argv[1] if len(sys.argv) > 1 else "/Users/admin/Documents/projects"
    )
    pqmf = PQMF(N=4, M=64, project_root=project_root)
    rs = np.random.RandomState(0)
    x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)
    a1 = pqmf.analysis(x)  # subband signals: [4, 2 * 4, (32000 + 64) // 4]
    a2 = pqmf.synthesis(a1)  # reconstruction, same shape as x
    print(a2.size(), x.size())
    plt.subplot(211)
    plt.plot(x[0, 0, -500:])
    plt.subplot(212)
    plt.plot(a2[0, 0, -500:])
    plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
    plt.show()
    print(torch.sum(torch.abs(x[...] - a2[...])))