# crepe/model.py

import functools
import torch
import torch.nn.functional as F
import crepe
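
# `crepe` here is the surrounding package; this module only relies on it for
# the crepe.PITCH_BINS constant used by the classifier below.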
###########################################################################
# Model definition
###########################################################################
class Crepe(torch.nn.Module):
    """Crepe model definition"""

    def __init__(self, model='full'):
        super().__init__()

        # Model-specific layer parameters
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256
        else:
            raise ValueError(f'Model {model} is not supported')
        # Shared layer parameters
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        # Use the eps and momentum values produced by the MMdnn conversion
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d,
                                          eps=0.0010000000474974513,
                                          momentum=0.0)
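
        # With momentum=0.0, PyTorch never updates the running statistics,
        # so the statistics loaded from the converted weights stay frozen
        # even if the module is put in training mode.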
        # Layer definitions
        self.conv1 = torch.nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0])
        self.conv1_BN = batch_norm_fn(
            num_features=out_channels[0])

        self.conv2 = torch.nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1])
        self.conv2_BN = batch_norm_fn(
            num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2])
        self.conv3_BN = batch_norm_fn(
            num_features=out_channels[2])

        self.conv4 = torch.nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=out_channels[3],
            kernel_size=kernel_sizes[3],
            stride=strides[3])
        self.conv4_BN = batch_norm_fn(
            num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(
            in_channels=in_channels[4],
            out_channels=out_channels[4],
            kernel_size=kernel_sizes[4],
            stride=strides[4])
        self.conv5_BN = batch_norm_fn(
            num_features=out_channels[4])

        self.conv6 = torch.nn.Conv2d(
            in_channels=in_channels[5],
            out_channels=out_channels[5],
            kernel_size=kernel_sizes[5],
            stride=strides[5])
        self.conv6_BN = batch_norm_fn(
            num_features=out_channels[5])

        self.classifier = torch.nn.Linear(
            in_features=self.in_features,
            out_features=crepe.PITCH_BINS)
    def forward(self, x, embed=False):
        # Forward pass through the first five layers
        x = self.embed(x)
        if embed:
            return x

        # Forward pass through layer six
        x = self.layer(x, self.conv6, self.conv6_BN)

        # shape=(batch, self.in_features)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)

        # Map the classifier logits to per-bin pitch probabilities
        return torch.sigmoid(self.classifier(x))
    ###########################################################################
    # Forward pass utilities
    ###########################################################################

    def embed(self, x):
        """Map input audio to pitch embedding"""
        # shape=(batch, 1, 1024, 1)
        x = x[:, None, :, None]

        # Forward pass through the first five layers; the first layer pads
        # the 1024-sample frame by 254 on each side so that kernel 512 with
        # stride 4 yields 256 outputs
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        x = self.layer(x, self.conv5, self.conv5_BN)

        return x
    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Forward pass through one layer"""
        # ReLU is applied before batch norm, matching the original Keras
        # CREPE graph (Conv2D with relu activation, then BatchNormalization)
        x = F.pad(x, padding)
        x = conv(x)
        x = F.relu(x)
        x = batch_norm(x)
        return F.max_pool2d(x, (2, 1), (2, 1))
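

###############################################################################
# Usage sketch
###############################################################################


# A minimal, hedged usage sketch, not part of the upstream module. It assumes
# the standard CREPE framing of 1024-sample windows and uses a freshly
# initialized (untrained) model; the shapes noted below are for the 'tiny'
# capacity.
if __name__ == '__main__':
    model = Crepe(model='tiny')
    model.eval()

    # Batch of 8 hypothetical 1024-sample audio frames
    frames = torch.randn(8, 1024)

    with torch.no_grad():
        # Per-frame pitch-bin probabilities, shape (8, crepe.PITCH_BINS)
        probabilities = model(frames)

        # Intermediate pitch embedding, shape (8, 32, 8, 1) for 'tiny'
        embedding = model(frames, embed=True)

    print(probabilities.shape, embedding.shape)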