# streamlit_demo/src/models/specrnet.py
"""
This file contains implementation of SpecRNet architecture.
We base our codebase on the implementation of RawNet2 by Hemlata Tak ([email protected]).
It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py
"""
from typing import Dict
import torch.nn as nn
from src import frontends


def get_config(input_channels: int) -> Dict:
    return {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }


class Residual_block2D(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])

        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv2d(
            in_channels=nb_filts[0],
            out_channels=nb_filts[1],
            kernel_size=3,
            padding=1,
            stride=1,
        )
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(
            in_channels=nb_filts[1],
            out_channels=nb_filts[1],
            padding=1,
            kernel_size=3,
            stride=1,
        )

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(
                in_channels=nb_filts[0],
                out_channels=nb_filts[1],
                padding=0,
                kernel_size=1,
                stride=1,
            )
        else:
            self.downsample = False

        self.mp = nn.MaxPool2d(2)
    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x

        # NOTE: conv1 is applied to the original input `x`, not to `out`, so the
        # bn1/LeakyReLU result above is discarded; this follows the upstream
        # RawNet2 baseline implementation this file is based on.
        out = self.conv1(x)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class SpecRNet(nn.Module):
    def __init__(self, input_channels, **kwargs):
        super().__init__()
        config = get_config(input_channels=input_channels)

        self.device = kwargs.get("device", "cuda")

        self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
        self.selu = nn.SELU(inplace=True)

        self.block0 = nn.Sequential(
            Residual_block2D(nb_filts=config["filts"][1], first=True)
        )
        self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        config["filts"][2][0] = config["filts"][2][1]
        self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.fc_attention0 = self._make_attention_fc(
            in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
        )
        self.fc_attention2 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )
        self.fc_attention4 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )

        self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
        self.gru = nn.GRU(
            input_size=config["filts"][2][-1],
            hidden_size=config["gru_node"],
            num_layers=config["nb_gru_layer"],
            batch_first=True,
            bidirectional=True,
        )

        self.fc1_gru = nn.Linear(
            in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
        )
        self.fc2_gru = nn.Linear(
            in_features=config["nb_fc_node"] * 2,
            out_features=config["nb_classes"],
            bias=True,
        )

        self.sig = nn.Sigmoid()
    def _compute_embedding(self, x):
        x = self.first_bn(x)
        x = self.selu(x)

        # Residual block 0 followed by channel attention: a sigmoid gate
        # scales and shifts the feature maps before downsampling.
        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
        y0 = y0.unsqueeze(-1)
        x = x0 * y0 + y0
        x = nn.MaxPool2d(2)(x)

        # Residual block 2 with the same attention pattern.
        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        y2 = y2.unsqueeze(-1)
        x = x2 * y2 + y2
        x = nn.MaxPool2d(2)(x)

        # Residual block 4 with the same attention pattern.
        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        y4 = y4.unsqueeze(-1)
        x = x4 * y4 + y4
        x = nn.MaxPool2d(2)(x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        # Collapse the frequency axis and keep the time axis for the GRU.
        x = nn.AdaptiveAvgPool2d((1, None))(x)
        x = x.squeeze(-2)
        x = x.permute(0, 2, 1)  # (bs, time, channels)

        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # last GRU time step
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return x

    def forward(self, x):
        x = self._compute_embedding(x)
        return x
    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
        return nn.Sequential(*l_fc)


class FrontendSpecRNet(SpecRNet):
    def __init__(self, input_channels, **kwargs):
        super().__init__(input_channels, **kwargs)

        self.device = kwargs["device"]

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    print("Definition of model")
    device = "cuda"

    input_channels = 1
    config = {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }

    def count_parameters(model) -> int:
        pytorch_total_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        return pytorch_total_params

    model = FrontendSpecRNet(input_channels=1, device=device, frontend_algorithm=["lfcc"])
    model = model.to(device)
    print(count_parameters(model))
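
    # Minimal sanity-check sketch: run a dummy forward pass and inspect the
    # output shape. Assumptions (not confirmed by this file): the "lfcc"
    # frontend accepts a raw waveform batch of shape (batch, samples), the
    # signal is 16 kHz, and the `device` chosen above is available.
    import torch

    dummy_waveform = torch.randn(2, 16_000, device=device)  # ~1 s at an assumed 16 kHz
    with torch.no_grad():
        out = model(dummy_waveform)
    print(out.shape)  # expected: (2, 1), a single logit per sample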