R-Kentaren committed on
Commit 5c76acc · verified · 1 Parent(s): b2b82a4

Create rmvpe.py

Files changed (1):
  1. pitch/rmvpe.py +614 -0
pitch/rmvpe.py ADDED
@@ -0,0 +1,614 @@
# These modules are licensed under the MIT License.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel
from librosa.util import pad_center, tiny, normalize
from scipy.signal import get_window

# Constants used by the classifier head below. The original referenced
# nn.N_MELS / nn.N_CLASS, which do not exist in torch.nn.
N_MELS = 128
N_CLASS = 360


# STFT code adapted from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
def window_sumsquare(
    window,
    n_frames,
    hop_length=200,
    win_length=800,
    n_fft=800,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = normalize(win_sq, norm=norm) ** 2
    win_sq = pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
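

# Illustrative sketch (not from the original file): for a Hann window with
# 75% overlap (hop = win / 4), the sum-square envelope is exactly flat away
# from the edges, which is what lets STFT.inverse() divide it out safely.
# The frame count below is an arbitrary demo value.
def _demo_window_sumsquare():
    env = window_sumsquare("hann", n_frames=20, hop_length=200, win_length=800, n_fft=800)
    assert env.shape == (800 + 200 * 19,)
    # Steady-state region: constant (1.5 for this window/hop) up to float error.
    interior = env[800:-800]
    assert np.allclose(interior, interior[0])
    return env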


class STFT(torch.nn.Module):
    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky, so there are some cases that probably won't work, as working
        out the same sizes before and after in all overlap-add setups is tough. Right now,
        this code should work with hop lengths that are half the filter length (50% overlap
        between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length or filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # window the bases
        forward_basis *= fft_window
        inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        """Take input data (audio) to the STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)
        """
        num_batches = input_data.shape[0]
        num_samples = input_data.shape[-1]

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
            mode="reflect",
        ).squeeze(1)

        forward_transform = F.conv1d(
            input_data, self.forward_basis, stride=self.hop_length, padding=0
        )

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        # The phase is required by inverse()/forward(); the original returned
        # only the magnitude, which broke the (magnitude, phase) unpacking in
        # forward() below.
        phase = torch.atan2(imag_part.data, real_part.data)
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ``transform`` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[..., self.pad_amount :]
        inverse_transform = inverse_transform[..., : self.num_samples]
        return inverse_transform.squeeze(1)

    def forward(self, input_data):
        """Take input data (audio) to the STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        self.magnitude, self.phase = self.transform(input_data)
        return self.inverse(self.magnitude, self.phase)
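

# Illustrative round-trip check (not from the original file): with transform()
# returning (magnitude, phase), forward() reconstructs the waveform up to small
# numerical error at these settings (hop = filter_length / 4). The signal
# parameters are arbitrary demo values.
def _demo_stft_roundtrip():
    stft = STFT(filter_length=1024, hop_length=256, win_length=1024, window="hann")
    t = torch.arange(16000, dtype=torch.float32) / 16000.0
    audio = torch.sin(2 * torch.pi * 220.0 * t).unsqueeze(0)  # (1, 16000)
    recon = stft(audio)  # transform + inverse
    assert recon.shape == audio.shape
    # Mean absolute reconstruction error; expected to be small away from edges.
    return (audio - recon).abs().mean().item()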


class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # A 1x1-conv shortcut is only needed when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        if self.is_shortcut:
            return self.conv(x) + self.shortcut(x)
        else:
            return self.conv(x) + x
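

# Illustrative shape check (not from the original file): the 3x3 convolutions
# are padded, so ConvBlockRes preserves the spatial dimensions while remapping
# channels. Input sizes below are arbitrary demo values.
def _demo_conv_block_res():
    block = ConvBlockRes(1, 16)
    x = torch.randn(2, 1, 32, 128)
    y = block(x)
    assert y.shape == (2, 16, 32, 128)
    return y.shape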


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for _ in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []
        x = self.bn(x)
        for i in range(self.n_encoders):
            # Each stage returns (pre-pool features, pooled output); the
            # pre-pool features become the U-Net skip connections.
            skip, x = self.layers[i](x)
            concat_tensors.append(skip)
        return x, concat_tensors


class ResEncoderBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i in range(self.n_blocks):
            x = self.conv[i](x)
        return (x, self.pool(x)) if self.kernel_size is not None else x


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        )
        for _ in range(self.n_inters - 1):
            self.layers.append(
                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        for i in range(self.n_inters):
            x = self.layers[i](x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        # The first block consumes the upsampled features concatenated with
        # the encoder skip tensor, hence out_channels * 2 input channels.
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for i in range(self.n_blocks):
            x = self.conv2[i](x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for _ in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i in range(self.n_decoders):
            x = self.layers[i](x, concat_tensors[-1 - i])
        return x


class DeepUnet(nn.Module):
    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x
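

# Illustrative shape check (not from the original file): with five (2, 2)
# pooling stages, both the frame and mel axes must be divisible by 32. The
# decoder mirrors the pooling, so spatial dims are restored and channels
# return to en_out_channels. n_blocks=1 is just a lightweight demo setting.
def _demo_deep_unet_shapes():
    net = DeepUnet(kernel_size=(2, 2), n_blocks=1)
    x = torch.randn(1, 1, 64, 128)  # (batch, channel, frames, mel bins)
    y = net(x)
    assert y.shape == (1, 16, 64, 128)
    return y.shape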


class E2E(nn.Module):
    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            # The original read nn.N_MELS / nn.N_CLASS here, which would raise
            # an AttributeError; use the module-level constants instead.
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        # (batch, n_mels, frames) -> (batch, 1, frames, n_mels)
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
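

# Illustrative shape check (not from the original file): the full E2E head maps
# a mel spectrogram (batch, N_MELS, frames) to per-frame pitch-bin
# probabilities (batch, frames, N_CLASS). Frames must be a multiple of 32 here;
# RMVPE.mel2hidden pads to guarantee that. The configuration matches what
# RMVPE.__init__ instantiates below.
def _demo_e2e_shapes():
    model = E2E(4, 1, (2, 2))
    mel = torch.randn(1, N_MELS, 64)
    with torch.no_grad():
        out = model(mel)
    assert out.shape == (1, 64, N_CLASS)
    return out.shape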


class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        keyshift_key = f"{keyshift}_{audio.device}"
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        # Note: the conv-based STFT is built lazily on the first call and then
        # reused, so the keyshift/speed of that first call fix its geometry.
        if not hasattr(self, "stft"):
            self.stft = STFT(
                filter_length=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window="hann",
            ).to(audio.device)
        magnitude, _ = self.stft.transform(audio)  # the phase is not needed here

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()

        return torch.log(torch.clamp(mel_output, min=self.clamp))
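

# Illustrative shape check (not from the original file): one second of 16 kHz
# audio with the parameters RMVPE uses below (128 mels, 1024-sample window,
# hop 160) yields 101 log-mel frames, since the STFT reflect-pads by 512 on
# each side and the conv then produces floor(16000 / 160) + 1 frames.
def _demo_mel_extractor():
    extractor = MelSpectrogram(False, 128, 16000, 1024, 160, None, 30, 8000)
    log_mel = extractor(torch.randn(1, 16000))
    assert log_mel.shape == (1, 128, 101)
    return log_mel.shape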


class RMVPE:
    def __init__(
        self,
        model_path: str,
        is_half: bool,
        hop_length: int = 160,
        mel_fmin: float = 30,
        mel_fmax: float = 8000,
        device: str | None = None,
    ):
        self.is_half = is_half
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, hop_length, None, mel_fmin, mel_fmax
        ).to(device)

        model = E2E(4, 1, (2, 2))
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()
        if is_half:
            model = model.half()
        self.model = model.to(device)
        # Bin i corresponds to 20 * i + 1997.38... cents above 10 Hz; pad by 4
        # on each side so the ±4-bin window in to_local_average_cents stays in
        # bounds (360 -> 368 entries).
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # Pad the frame axis to a multiple of 32 (the U-Net halves it five times).
            mel = F.pad(
                mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
            )
            hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0  # cents of 0 marks an unvoiced frame
        return f0

    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03):
        mel = self.mel_extractor(
            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
        )
        hidden = self.mel2hidden(mel)
        hidden = hidden.squeeze(0).cpu().numpy()
        if self.is_half:
            hidden = hidden.astype("float32")
        return self.decode(hidden, thred=thred)

    def to_local_average_cents(self, salience, thred=0.05):
        center = np.argmax(salience, axis=1)  # frame-wise peak bin
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx] : ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])

        # Weighted average of cents over the 9-bin window around each peak.
        todo_salience = np.array(todo_salience)
        todo_cents_mapping = np.array(todo_cents_mapping)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)
        divided = product_sum / weight_sum

        maxx = np.max(salience, axis=1)
        divided[maxx <= thred] = 0  # below-threshold frames are unvoiced
        return divided
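

# Usage sketch (illustrative, not part of the original file). The checkpoint
# and wav paths are placeholders; audio must be 16 kHz mono float32, matching
# the MelSpectrogram configured in __init__. Each returned value is an f0 in
# Hz per hop (10 ms at hop_length=160), with 0 marking unvoiced frames. The
# pitch grid spans roughly 10 * 2**(1997.38/1200) ~= 32 Hz (bin 0) up to
# 10 * 2**((20 * 359 + 1997.38)/1200) ~= 2000 Hz (bin 359), in 20-cent steps.
def _demo_rmvpe(model_path="rmvpe.pt", wav_path="voice.wav"):
    import librosa

    audio, _ = librosa.load(wav_path, sr=16000, mono=True)
    rmvpe = RMVPE(model_path, is_half=False, device="cpu")
    f0 = rmvpe.infer_from_audio(audio, thred=0.03)
    return f0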