# src/models/specrnet.py
"""
This file contains implementation of SpecRNet architecture.
We base our codebase on the implementation of RawNet2 by Hemlata Tak ([email protected]).
It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py
"""
from typing import Dict
import torch.nn as nn
from src import frontends
def get_config(input_channels: int) -> Dict:
    """Return the SpecRNet hyper-parameter dict for a given input-channel count.

    "filts" lists the [in, out] channel pairs of the residual blocks; the
    remaining keys size the GRU head and the fully-connected classifier.
    """
    filter_config = [
        input_channels,
        [input_channels, 20],
        [20, 64],
        [64, 64],
    ]
    return {
        "filts": filter_config,
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }
class Residual_block2D(nn.Module):
    """2D residual block: pre-activation (BN + LeakyReLU), two 3x3 convs,
    an optional 1x1 projection on the identity path, and a final 2x2 max-pool.

    Args:
        nb_filts: [in_channels, out_channels] pair for this block.
        first: when True, skip the initial BN + LeakyReLU pre-activation
            (used for the very first block, whose input is normalized by
            the enclosing network instead).
    """

    def __init__(self, nb_filts, first: bool = False):
        super().__init__()
        self.first = first
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.lrelu = nn.LeakyReLU(negative_slope=0.3)
        self.conv1 = nn.Conv2d(
            in_channels=nb_filts[0],
            out_channels=nb_filts[1],
            kernel_size=3,
            padding=1,
            stride=1,
        )
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(
            in_channels=nb_filts[1],
            out_channels=nb_filts[1],
            padding=1,
            kernel_size=3,
            stride=1,
        )
        # 1x1 projection so the identity can be added when channels differ.
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(
                in_channels=nb_filts[0],
                out_channels=nb_filts[1],
                padding=0,
                kernel_size=1,
                stride=1,
            )
        else:
            self.downsample = False
        self.mp = nn.MaxPool2d(2)

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x
        # BUGFIX: was `self.conv1(x)`, which silently discarded the
        # BN + LeakyReLU pre-activation computed just above (a bug inherited
        # from the RawNet2 baseline this file is based on). Feed the
        # pre-activated tensor instead. NOTE: this changes the numerics of
        # checkpoints trained with the old code.
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out
class SpecRNet(nn.Module):
    """SpecRNet classifier over 2D time-frequency inputs.

    Adapted from the RawNet2 baseline (see module docstring): three
    residual blocks, each followed by a sigmoid-gated channel attention
    ("scale and shift" by the gate) and 2x2 max-pooling, then a
    bidirectional GRU over the time axis and two fully-connected layers.

    Note: the final output of `forward` is a raw score of size
    `nb_classes` (= 1) — `self.sig` is only used inside the attention
    gates, not on the output.
    """

    def __init__(self, input_channels, **kwargs):
        super().__init__()
        config = get_config(input_channels=input_channels)
        # Stored for API compatibility; moving the module to a device is
        # still the caller's job via `.to(device)`.
        self.device = kwargs.get("device", "cuda")
        self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
        self.selu = nn.SELU(inplace=True)
        # First residual block skips pre-activation (input just went
        # through first_bn + selu).
        self.block0 = nn.Sequential(
            Residual_block2D(nb_filts=config["filts"][1], first=True)
        )
        self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        # block4 consumes block2's output channels, so patch the in-channel
        # entry of the shared config list before reusing it.
        config["filts"][2][0] = config["filts"][2][1]
        self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # One single-layer FC per attention gate, sized to the block's
        # output channels.
        self.fc_attention0 = self._make_attention_fc(
            in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
        )
        self.fc_attention2 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )
        self.fc_attention4 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )
        self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
        self.gru = nn.GRU(
            input_size=config["filts"][2][-1],
            hidden_size=config["gru_node"],
            num_layers=config["nb_gru_layer"],
            batch_first=True,
            bidirectional=True,
        )
        # `* 2` because the GRU is bidirectional (forward + backward states
        # are concatenated).
        self.fc1_gru = nn.Linear(
            in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
        )
        self.fc2_gru = nn.Linear(
            in_features=config["nb_fc_node"] * 2,
            out_features=config["nb_classes"],
            bias=True,
        )
        self.sig = nn.Sigmoid()

    def _compute_embedding(self, x):
        """Map a (bs, C, F, T) input to (bs, nb_classes) raw scores.

        # NOTE(review): input is presumably (batch, channels, freq, time),
        # given the frequency axis is averaged out before the GRU — confirm
        # against the frontend's output layout.
        """
        x = self.first_bn(x)
        x = self.selu(x)

        # Stage 0: residual block + sigmoid channel gate (multiply & add).
        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)  # global avg pool -> (bs, C)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
        y0 = y0.unsqueeze(-1)  # broadcastable (bs, C, 1, 1) gate
        x = x0 * y0 + y0
        x = nn.MaxPool2d(2)(x)

        # Stage 2: same pattern with the second residual block.
        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        y2 = y2.unsqueeze(-1)
        x = x2 * y2 + y2
        x = nn.MaxPool2d(2)(x)

        # Stage 4: same pattern with the third residual block.
        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        y4 = y4.unsqueeze(-1)
        x = x4 * y4 + y4
        x = nn.MaxPool2d(2)(x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = nn.AdaptiveAvgPool2d((1, None))(x)  # collapse the frequency axis
        x = x.squeeze(-2)
        x = x.permute(0, 2, 1)  # -> (bs, time, channels) for batch_first GRU
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # keep only the last time step's hidden output
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return x

    def forward(self, x):
        x = self._compute_embedding(x)
        return x

    def _make_attention_fc(self, in_features, l_out_features):
        # Single linear layer wrapped in Sequential (kept as a list build
        # for parity with the original RawNet2 code structure).
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
        return nn.Sequential(*l_fc)
class FrontendSpecRNet(SpecRNet):
    """SpecRNet preceded by an on-the-fly audio frontend (e.g. LFCC).

    The frontend is resolved by name via `frontends.get_frontend` and
    applied to the raw input before the SpecRNet embedding.
    """

    def __init__(self, input_channels, **kwargs):
        super().__init__(input_channels, **kwargs)
        # BUGFIX: was `kwargs['device']`, which raised KeyError when the
        # caller omitted `device`, while the parent __init__ falls back to
        # "cuda". Use the same default-aware lookup for consistency.
        self.device = kwargs.get("device", "cuda")
        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        """Apply the frontend; ensure a 4D (bs, ch, n_lfcc, frames) tensor."""
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        x = self._compute_embedding(x)
        return x
if __name__ == "__main__":
    # Smoke test: build the model and report its trainable parameter count.
    print("Definition of model")
    # Local import: the module top only binds `nn` (`import torch.nn as nn`),
    # so `torch` itself is not in scope here.
    import torch

    # BUGFIX: was hard-coded to "cuda", which crashed on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def count_parameters(model) -> int:
        """Return the number of trainable parameters in `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = FrontendSpecRNet(
        input_channels=1, device=device, frontend_algorithm=["lfcc"]
    )
    model = model.to(device)
    print(count_parameters(model))