Spaces:
Running
Running
File size: 1,548 Bytes
bb70eb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
class STFT:
    """Batched short-time Fourier transform with real/imaginary parts as channels.

    ``__call__`` maps a real waveform tensor of shape ``(*batch, c, t)`` to a
    real spectrogram of shape ``(*batch, 2 * c, min(dim_f, n_fft // 2 + 1),
    frames)``. Input channel ``i`` becomes output channels ``2 * i`` (real
    part) and ``2 * i + 1`` (imaginary part). ``inverse`` undoes that layout
    and runs ``torch.istft``.

    Args:
        n_fft: FFT size; also the length of the periodic Hann window.
        hop_length: Number of samples between successive frames.
        dim_f: Number of frequency bins kept by ``__call__``; higher bins are
            dropped there and zero-filled back by ``inverse``.
    """

    def __init__(self, n_fft, hop_length, dim_f):
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Built once at construction; moved to the input's device on each call.
        self.window = torch.hann_window(window_length=n_fft, periodic=True)
        self.dim_f = dim_f

    def __call__(self, x):
        """Return the spectrogram of ``x``.

        Args:
            x: Real tensor of shape ``(*batch, c, t)``.

        Returns:
            Real tensor of shape
            ``(*batch, 2 * c, min(dim_f, n_fft // 2 + 1), frames)`` with the
            real and imaginary parts of each input channel interleaved along
            the channel axis.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        # torch.stft takes (batch, time); fold every leading dim into one.
        x = x.reshape([-1, t])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )
        # (B*c, F, T) complex -> (B*c, F, T, 2) real -> (B*c, 2, F, T).
        x = torch.view_as_real(x).permute([0, 3, 1, 2])
        # Merge the (c, 2) pair into one channel axis (re/im per channel);
        # a single reshape keeps the same row-major element order.
        x = x.reshape([*batch_dims, c * 2, -1, x.shape[-1]])
        return x[..., : self.dim_f, :]

    def inverse(self, x):
        """Invert ``__call__``: turn spectrogram channels back into waveforms.

        Args:
            x: Real tensor of shape ``(*batch, c, f, frames)``. ``c`` must be
                even (real/imaginary pairs, as produced by ``__call__``) and
                ``f`` must not exceed ``n_fft // 2 + 1``.

        Returns:
            Real tensor of shape ``(*batch, c // 2, samples)``. ``samples`` is
            derived by ``torch.istft`` from the frame count (no ``length`` is
            passed), so it need not equal the original signal length.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1
        # Zero-fill the high-frequency bins that __call__ dropped via dim_f.
        # Allocate directly on x's device (default dtype, as before) rather
        # than building on the CPU and copying over.
        f_pad = torch.zeros([*batch_dims, c, n - f, t], device=x.device)
        x = torch.cat([x, f_pad], -2)
        # Split channels into (c // 2, 2) first so an odd c still raises,
        # then flatten to one istft batch: (B * c // 2, 2, n, t).
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        # view_as_complex needs the size-2 re/im axis last with stride 1.
        x = x.permute([0, 2, 3, 1]).contiguous()
        t_complex = torch.view_as_complex(x)
        x = torch.istft(
            t_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
        )
        # One waveform per real/imag channel pair: c // 2 output channels
        # (not a fixed 2, which mangled any non-stereo input).
        return x.reshape([*batch_dims, c // 2, -1])
|