import torch from torch import nn from torch.nn import functional as F import torchaudio import pytorch_lightning as pl from torchmetrics import Accuracy, F1, Precision, Recall import torch.nn as nn import torch.nn.functional as F class M11(pl.LightningModule): def __init__(self, hidden_units_1, hidden_units_2, dropout_1, dropout_2, n_input=1, n_output=3, stride=4, n_channel=64, lr=1e-3, l2=1e-5): super().__init__() self.save_hyperparameters() self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride) self.bn1 = nn.BatchNorm1d(n_channel) self.pool1 = nn.MaxPool1d(4) self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3,padding=1) self.bn2 = nn.BatchNorm1d(n_channel) self.conv3 = nn.Conv1d(n_channel, n_channel, kernel_size=3,padding=1) self.bn3 = nn.BatchNorm1d(n_channel) self.pool2 = nn.MaxPool1d(4) self.conv4 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3,padding=1) self.bn4 = nn.BatchNorm1d(2 * n_channel) self.conv5 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3,padding=1) self.bn5 = nn.BatchNorm1d(2 * n_channel) self.pool3 = nn.MaxPool1d(4) self.conv6 = nn.Conv1d(2 * n_channel, 4 * n_channel, kernel_size=3,padding=1) self.bn6 = nn.BatchNorm1d(4 * n_channel) self.conv7 = nn.Conv1d(4 * n_channel, 4 * n_channel, kernel_size=3,padding=1) self.bn7 = nn.BatchNorm1d(4 * n_channel) self.conv8 = nn.Conv1d(4 * n_channel, 4 * n_channel, kernel_size=3,padding=1) self.bn8 = nn.BatchNorm1d(4 * n_channel) self.pool4 = nn.MaxPool1d(4) self.conv9 = nn.Conv1d(4 * n_channel, 8 * n_channel, kernel_size=3,padding=1) self.bn9 = nn.BatchNorm1d(8 * n_channel) self.conv10 = nn.Conv1d(8 * n_channel, 8 * n_channel, kernel_size=3,padding=1) self.bn10 = nn.BatchNorm1d(8 * n_channel) # self.fc1 = nn.Linear(8 * n_channel, n_output) self.mlp = nn.Sequential( nn.Linear(8 * n_channel, hidden_units_1), nn.ReLU(), nn.Dropout(dropout_1), nn.Linear(hidden_units_1, hidden_units_2), nn.ReLU(), nn.Dropout(dropout_2), nn.Linear(hidden_units_2, n_output) ) def forward(self, x): x = self.conv1(x) x = F.relu(self.bn1(x)) x = self.pool1(x) x = self.conv2(x) x = F.relu(self.bn2(x)) x = self.conv3(x) x = F.relu(self.bn3(x)) x = self.pool2(x) x = self.conv4(x) x = F.relu(self.bn4(x)) x = self.conv5(x) x = F.relu(self.bn5(x)) x = self.pool3(x) x = self.conv6(x) x = F.relu(self.bn6(x)) x = self.conv7(x) x = F.relu(self.bn7(x)) x = self.conv8(x) x = F.relu(self.bn8(x)) x = self.pool4(x) x = self.conv9(x) x = F.relu(self.bn9(x)) x = self.conv10(x) x = F.relu(self.bn10(x)) x = F.avg_pool1d(x, x.shape[-1]) x = x.permute(0, 2, 1) # x = self.fc1(x) x = self.mlp(x) return F.log_softmax(x, dim=2) def training_step(self, batch, batch_idx): # Very simple training loop data, target = batch logits = self(data) # this calls self.forward preds = torch.argmax(logits, dim=-1).squeeze() # loss = cost(logits.squeeze(), target) loss = unweighted_cost(logits.squeeze(), target) f1 = f1_metric(preds, target) self.log('train_loss', loss, on_epoch=True, prog_bar=True) self.log('train_f1', f1, on_epoch=True, prog_bar=True) return loss def validation_step(self, batch, batch_idx): data, target = batch logits = self(data) preds = torch.argmax(logits, dim=-1).squeeze() # loss = val_cost(logits.squeeze(), target) loss = unweighted_cost(logits.squeeze(), target) acc = accuracy(preds, target) f1 = f1_metric(preds, target) prec = precision(preds, target) rec = recall(preds, target) self.log('val_loss', loss, on_epoch=True, prog_bar=True) self.log('val_acc', acc, on_epoch=True, prog_bar=True) self.log('val_f1', f1, on_epoch=True, 

# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# model_PATH = "./model.ckpt"
# audio_PATH = "./sample_audio.wav"


# def _resample_if_necessary(signal, sr, device):
#     if sr != 8_000:
#         resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
#         signal = resampler(signal)
#     return signal


# def _mix_down_if_necessary(signal):
#     if signal.shape[0] > 1:
#         signal = torch.mean(signal, dim=0, keepdim=True)
#     return signal


# def get_likely_index(tensor):
#     # find most likely label index for each element in the batch
#     return tensor.argmax(dim=-1)


# model = M11.load_from_checkpoint(model_PATH).to(DEVICE)
# model.eval()
# audio, sr = torchaudio.load(audio_PATH)
# # resampler = torchaudio.transforms.Resample(sr, 8_000).to(DEVICE)
# processed_audio = _mix_down_if_necessary(_resample_if_necessary(audio, sr, DEVICE))
# print(processed_audio.shape)
# with torch.no_grad():
#     pred = get_likely_index(model(processed_audio.unsqueeze(0).to(DEVICE))).view(-1)
#     # y_true = target.tolist()
#     # y_pred = pred.tolist()
#     # target_names = eval_dataset.label_list
#     # print(classification_report(y_true, y_pred, target_names=target_names))
#     print(pred)
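
# A minimal smoke test of the architecture (hypothetical hyperparameter
# values; the originals are not shown in this file): one second of 8 kHz mono
# audio through an untrained model should yield one log-probability vector
# per example, shape (batch, 1, n_output).
if __name__ == "__main__":
    model = M11(hidden_units_1=256, hidden_units_2=128,
                dropout_1=0.3, dropout_2=0.3)
    model.eval()  # disable dropout for a deterministic forward pass
    dummy = torch.randn(2, 1, 8_000)  # (batch, channels, samples)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([2, 1, 3])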