import torch
from torch import nn
from torch.nn import functional as F

import pytorch_lightning as pl
# torchmetrics renamed `F1` to `F1Score`; the task-based constructor
# arguments used below assume torchmetrics >= 0.11.
from torchmetrics import Accuracy, F1Score, Precision, Recall


class M11(pl.LightningModule):
    """M11-style deep 1-D CNN over raw waveforms with an MLP classification head."""

    def __init__(self, hidden_units_1, hidden_units_2, dropout_1, dropout_2,
                 n_input=1, n_output=3, stride=4, n_channel=64, lr=1e-3, l2=1e-5):
        super().__init__()
        self.save_hyperparameters()

        # Stage 1: wide strided conv directly over the raw waveform.
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)

        # Stage 2: two convs at n_channel.
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.conv3 = nn.Conv1d(n_channel, n_channel, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)

        # Stage 3: two convs at 2 * n_channel.
        self.conv4 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.conv5 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)

        # Stage 4: three convs at 4 * n_channel.
        self.conv6 = nn.Conv1d(2 * n_channel, 4 * n_channel, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm1d(4 * n_channel)
        self.conv7 = nn.Conv1d(4 * n_channel, 4 * n_channel, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm1d(4 * n_channel)
        self.conv8 = nn.Conv1d(4 * n_channel, 4 * n_channel, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm1d(4 * n_channel)
        self.pool4 = nn.MaxPool1d(4)

        # Stage 5: two convs at 8 * n_channel, followed by global average
        # pooling in forward().
        self.conv9 = nn.Conv1d(4 * n_channel, 8 * n_channel, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm1d(8 * n_channel)
        self.conv10 = nn.Conv1d(8 * n_channel, 8 * n_channel, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm1d(8 * n_channel)

        # Classification head.
        self.mlp = nn.Sequential(
            nn.Linear(8 * n_channel, hidden_units_1),
            nn.ReLU(),
            nn.Dropout(dropout_1),
            nn.Linear(hidden_units_1, hidden_units_2),
            nn.ReLU(),
            nn.Dropout(dropout_2),
            nn.Linear(hidden_units_2, n_output),
        )

        # The step methods below referenced `accuracy`, `f1_metric`, `precision`
        # and `recall` as undefined globals; registering the metrics as
        # submodules makes the class self-contained and lets Lightning move
        # them to the right device.
        self.acc_metric = Accuracy(task="multiclass", num_classes=n_output)
        self.f1_metric = F1Score(task="multiclass", num_classes=n_output)
        self.prec_metric = Precision(task="multiclass", num_classes=n_output)
        self.rec_metric = Recall(task="multiclass", num_classes=n_output)

    def forward(self, x):
        # x: (batch, n_input, time)
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool2(x)

        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.conv5(x)
        x = F.relu(self.bn5(x))
        x = self.pool3(x)

        x = self.conv6(x)
        x = F.relu(self.bn6(x))
        x = self.conv7(x)
        x = F.relu(self.bn7(x))
        x = self.conv8(x)
        x = F.relu(self.bn8(x))
        x = self.pool4(x)

        x = self.conv9(x)
        x = F.relu(self.bn9(x))
        x = self.conv10(x)
        x = F.relu(self.bn10(x))

        # Global average pooling over time, then move the singleton time axis
        # to the middle: (batch, 8*n_channel, 1) -> (batch, 1, 8*n_channel).
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)

        x = self.mlp(x)
        return F.log_softmax(x, dim=2)

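    # Shape trace for a hypothetical 1 s mono input at 16 kHz with the default
    # n_channel=64 (the input length is an illustration, not a requirement):
    #   (B, 1, 16000) -conv1/pool1-> (B, 64, 995) -stage2-> (B, 64, 248)
    #   -stage3-> (B, 128, 62) -stage4-> (B, 256, 15) -stage5-> (B, 512, 15)
    #   -global avg pool + permute-> (B, 1, 512) -mlp-> (B, 1, 3)
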
    def training_step(self, batch, batch_idx):
        data, target = batch
        logits = self(data)
        # Squeeze the singleton time dimension explicitly (dim 1) so a batch
        # of size 1 is not collapsed along with it.
        preds = torch.argmax(logits, dim=-1).squeeze(1)

        # The network returns log-probabilities, so NLL is the matching
        # criterion; the undefined `unweighted_cost` helper in the original is
        # assumed here to be an unweighted NLL loss.
        loss = F.nll_loss(logits.squeeze(1), target)

        f1 = self.f1_metric(preds, target)

        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_f1', f1, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        logits = self(data)
        preds = torch.argmax(logits, dim=-1).squeeze(1)

        loss = F.nll_loss(logits.squeeze(1), target)

        acc = self.acc_metric(preds, target)
        f1 = self.f1_metric(preds, target)
        prec = self.prec_metric(preds, target)
        rec = self.rec_metric(preds, target)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)
        self.log('val_precision', prec, on_epoch=True, prog_bar=True)
        self.log('val_recall', rec, on_epoch=True, prog_bar=True)
        # Lightning aggregates the logged metrics itself, so returning the
        # loss alone suffices (the original returned an unused tuple).
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.l2)
        return optimizer
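

# Minimal smoke test (hypothetical values): the hidden-layer sizes, dropout
# rates, and the 1 s / 16 kHz input below are illustrative assumptions, not
# values taken from the original training setup.
if __name__ == "__main__":
    model = M11(hidden_units_1=256, hidden_units_2=128, dropout_1=0.3, dropout_2=0.3)
    waveforms = torch.randn(8, 1, 16000)   # batch of 8 mono waveforms
    log_probs = model(waveforms)
    print(log_probs.shape)                 # torch.Size([8, 1, 3])
    # For actual training, with assumed DataLoaders yielding (waveform, label):
    # pl.Trainer(max_epochs=20, accelerator="auto").fit(model, train_loader, val_loader)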