""" | |
This file contains implementation of SpecRNet architecture. | |
We base our codebase on the implementation of RawNet2 by Hemlata Tak ([email protected]). | |
It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py | |
""" | |
from typing import Dict | |
import torch.nn as nn | |
from src import frontends | |


def get_config(input_channels: int) -> Dict:
    return {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }
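
# Note on the config layout (illustrative, derived from the function above):
# get_config(1) yields filts=[1, [1, 20], [20, 64], [64, 64]], i.e. the first
# residual block maps 1 -> 20 channels and the later blocks operate at 64.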


class Residual_block2D(nn.Module):
    """Pre-activation 2D residual block with max-pool downsampling."""

    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        # The first block skips the initial BN + activation; its input comes
        # straight from the stem batch norm in SpecRNet.
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv2d(
            in_channels=nb_filts[0],
            out_channels=nb_filts[1],
            kernel_size=3,
            padding=1,
            stride=1,
        )
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(
            in_channels=nb_filts[1],
            out_channels=nb_filts[1],
            padding=1,
            kernel_size=3,
            stride=1,
        )

        # Match channel counts on the skip connection when they differ.
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(
                in_channels=nb_filts[0],
                out_channels=nb_filts[1],
                padding=0,
                kernel_size=1,
                stride=1,
            )
        else:
            self.downsample = False
        self.mp = nn.MaxPool2d(2)
    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x
        # Feed the pre-activation output into conv1. (The RawNet2 baseline
        # passed `x` here, which silently discarded the bn1 + lrelu result.)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out
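
# Shape sketch (illustrative, assuming nb_filts=[20, 64]): an input of shape
# (bs, 20, H, W) leaves conv2 as (bs, 64, H, W); the skip path is projected
# to 64 channels by the 1x1 conv, and MaxPool2d(2) then halves both spatial
# dims, giving (bs, 64, H // 2, W // 2).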


class SpecRNet(nn.Module):
    """SpecRNet: a RawNet2-style classifier operating on 2D spectral features."""

    def __init__(self, input_channels, **kwargs):
        super().__init__()
        config = get_config(input_channels=input_channels)

        self.device = kwargs.get("device", "cuda")

        self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
        self.selu = nn.SELU(inplace=True)
        self.block0 = nn.Sequential(
            Residual_block2D(nb_filts=config["filts"][1], first=True)
        )
        self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        config["filts"][2][0] = config["filts"][2][1]
        self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc_attention0 = self._make_attention_fc(
            in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
        )
        self.fc_attention2 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )
        self.fc_attention4 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )

        self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
        self.gru = nn.GRU(
            input_size=config["filts"][2][-1],
            hidden_size=config["gru_node"],
            num_layers=config["nb_gru_layer"],
            batch_first=True,
            bidirectional=True,
        )
        self.fc1_gru = nn.Linear(
            in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
        )
        self.fc2_gru = nn.Linear(
            in_features=config["nb_fc_node"] * 2,
            out_features=config["nb_classes"],
            bias=True,
        )
        self.sig = nn.Sigmoid()
    def _compute_embedding(self, x):
        x = self.first_bn(x)
        x = self.selu(x)

        # Each stage applies a residual block followed by feature map scaling
        # (FMS), as in RawNet2: a sigmoid-gated channel attention applied both
        # multiplicatively and additively.
        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
        y0 = y0.unsqueeze(-1)
        x = x0 * y0 + y0
        x = nn.MaxPool2d(2)(x)

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        y2 = y2.unsqueeze(-1)
        x = x2 * y2 + y2
        x = nn.MaxPool2d(2)(x)

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        y4 = y4.unsqueeze(-1)
        x = x4 * y4 + y4
        x = nn.MaxPool2d(2)(x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        # Collapse the frequency axis, keep time, and feed the GRU a
        # (bs, frames, channels) sequence; its last output step is the
        # utterance embedding.
        x = nn.AdaptiveAvgPool2d((1, None))(x)
        x = x.squeeze(-2)
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return x
    def forward(self, x):
        x = self._compute_embedding(x)
        return x

    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
        return nn.Sequential(*l_fc)
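
# Minimal usage sketch (illustrative; the input layout (bs, channels,
# freq_bins, frames) and the sizes below are assumptions, not fixed by the
# model):
#
#   model = SpecRNet(input_channels=1, device="cpu")
#   logits = model(torch.rand(2, 1, 80, 404))  # -> shape (2, 1), raw scores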


class FrontendSpecRNet(SpecRNet):
    """SpecRNet preceded by a frontend (e.g. LFCC) that maps raw audio to
    2D spectral features."""

    def __init__(self, input_channels, **kwargs):
        super().__init__(input_channels, **kwargs)

        self.device = kwargs.get("device", "cuda")

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    print("Definition of model")
    device = "cuda"
    input_channels = 1

    def count_parameters(model) -> int:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = FrontendSpecRNet(
        input_channels=input_channels, device=device, frontend_algorithm=["lfcc"]
    )
    model = model.to(device)
    print(count_parameters(model))
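
    # Hedged smoke test (the shapes here are assumptions, not part of the
    # original file): feed a random raw-audio batch through the frontend and
    # classifier. 64600 samples matches the ~4 s clips used by the ASVspoof
    # RawNet2 baseline; adjust if your frontend expects a different length.
    import torch

    model.eval()
    with torch.no_grad():
        batch = torch.rand(2, 64600, device=device)  # (bs, num_samples)
        scores = model(batch)
    print(scores.shape)  # expected: torch.Size([2, 1])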