|
""" |
|
This script implements a deep learning pipeline for audio classification using a pre-trained MobileNetV2 model. |
|
The pipeline includes data loading, model training, evaluation, and emissions tracking. |
|
""" |
|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torchaudio |
|
from torch.utils.data import Dataset, DataLoader |
|
import numpy as np |
|
from transformers import AutoModelForImageClassification |
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
|
from tqdm import tqdm |
|
import logging |
|
from datasets import load_dataset |
|
from accelerate import Accelerator |
|
from codecarbon import EmissionsTracker |
|
import time |
|
|
|
|
|
class Config: |
|
""" |
|
Configuration class to store hyperparameters and model settings. |
|
""" |
|
SAMPLE_RATE = 16000 |
|
N_FFT = 800 |
|
N_MELS = 128 |
|
HOP_LENGTH = None |
|
SIZE = (96, 96) |
|
SCALING_DIM = (1, 2) |
|
LEARNING_RATE = 0.0005 |
|
BATCH_SIZE = 32 |
|
NUM_WORKERS = 4 |
|
NUM_EPOCHS = 1 |
|
MODEL_NAME = "google/mobilenet_v2_0.35_96" |
|
MODEL_PATH = "scaled_model_800_128_96x96_mobilenet_small_unscaled_submission.pth" |
|
|
|
config = Config() |
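
# Rough sanity check of the spectrogram geometry implied by these settings
# (a sketch, assuming torchaudio's defaults of win_length = N_FFT and
# hop_length = N_FFT // 2 when HOP_LENGTH is None):
#   samples per clip : 4.5 s * 16000 Hz = 72000
#   hop length       : 800 // 2 = 400
#   mel frames       : 72000 // 400 + 1 = 181
# so each clip becomes a 128 x 181 mel spectrogram before the bilinear resize
# to SIZE = (96, 96) in AudioClassifier.forward.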
|
|
|
class AudioDataset(Dataset): |
|
""" |
|
Custom Dataset class for loading and processing audio data. |
|
|
|
Args: |
|
data (list): List of audio data samples. |
|
sample_rate (int, optional): Target sample rate for audio resampling. Defaults to 16000. |
|
audio_target_length (float, optional): Target length of audio in seconds. Defaults to 4.5. |
|
""" |
|
def __init__(self, data, sample_rate=16000, audio_target_length=4.5): |
|
self.data = data |
|
self.sample_rate = sample_rate |
|
self.audio_target_length = audio_target_length |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, index): |
|
|
|
        # Cache one Resample transform per (input length, target length) pair,
        # built lazily on first access.
        if not hasattr(self, '_resampler_cache'):
            self._resampler_cache = {}
|
|
|
|
|
        data_item = self.data[index]
        audio_array = data_item["audio"]["array"]
        # Fall back to a short, near-silent waveform when the audio clip is empty.
        waveform = torch.FloatTensor(audio_array) if len(audio_array) > 0 else torch.ones(36000) * 1e-5
|
|
|
|
|
        # Resample is used here as a length normaliser: treating the current number
        # of samples as the "original rate" and the desired number of samples as the
        # "new rate" stretches or squeezes every clip to audio_target_length seconds.
        orig_freq = waveform.shape[-1]
        target_freq = int(self.audio_target_length * self.sample_rate)
        resampler_key = (orig_freq, target_freq)

        if resampler_key not in self._resampler_cache:
            self._resampler_cache[resampler_key] = torchaudio.transforms.Resample(
                orig_freq=orig_freq,
                new_freq=target_freq
            )

        return self._resampler_cache[resampler_key](waveform), data_item["label"]
|
|
|
|
|
def collate_fn(batch): |
|
""" |
|
Collate function to stack inputs and labels into batches. |
|
Args: |
|
batch (list): List of tuples containing inputs and labels. |
|
Returns: |
|
tuple: Stacked inputs and labels. |
|
""" |
|
return torch.stack([inputs for inputs, _ in batch]), torch.tensor([label for _, label in batch]) |
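
# Usage sketch (assumes a Hugging Face dataset split with "audio" and "label"
# columns; the dataset name below is a placeholder, not part of this script):
#
#   raw = load_dataset("placeholder/audio-dataset", split="test")
#   test_set = AudioDataset(raw, sample_rate=config.SAMPLE_RATE)
#   test_loader = DataLoader(test_set,
#                            batch_size=config.BATCH_SIZE,
#                            num_workers=config.NUM_WORKERS,
#                            collate_fn=collate_fn)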
|
|
|
|
|
class AudioClassifier(nn.Module): |
|
""" |
|
Audio classification model using a pre-trained MobileNetV2. |
|
Args: |
|
model_name (str): Name of the pre-trained model. |
|
model_path (str): Path to save/load the model. |
|
new (bool, optional): Whether to load a new model or an existing one. Defaults to True. |
|
""" |
|
def __init__(self, model_name, model_path, new=True): |
|
super().__init__() |
|
self.model = self.load_model(model_name, model_path, new) |
|
self.num_classes = 2 |
|
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( |
|
sample_rate=config.SAMPLE_RATE, |
|
n_fft=config.N_FFT, |
|
n_mels=config.N_MELS, |
|
hop_length=config.HOP_LENGTH |
|
) |
|
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() |
|
|
|
def load_model(self, model_name, model_path, new=False): |
|
""" |
|
Load the pre-trained model and modify the classifier. |
|
|
|
Args: |
|
model_name (str): Name of the pre-trained model. |
|
model_path (str): Path to save/load the model. |
|
new (bool, optional): Whether to load a new model or an existing one. Defaults to False. |
|
Returns: |
|
nn.Module: Loaded model. |
|
""" |
|
        model = AutoModelForImageClassification.from_pretrained(model_name)
        # Swap the ImageNet head for a 2-class head; the Sequential wrapper is kept
        # so saved checkpoint keys ("classifier.0.*") still line up.
        model.classifier = nn.Sequential(
            nn.Linear(in_features=1280, out_features=2))

        for param in model.parameters():
            param.requires_grad = True

        # Only restore fine-tuned weights when an existing checkpoint is requested.
        if not new:
            state_dict = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state_dict)
        return model
|
|
|
def forward(self, waveforms): |
|
""" |
|
Forward pass through the model. |
|
Args: |
|
waveforms (torch.Tensor): Input audio waveforms. |
|
Returns: |
|
torch.Tensor: Model output. |
|
""" |
|
melspectrogram = self.mel_spectrogram(waveforms) |
|
melspectrogram = nn.functional.interpolate(melspectrogram.unsqueeze(1), |
|
size=config.SIZE, |
|
mode="bilinear", |
|
align_corners=False).squeeze(1) |
|
db_melspectrogram = self.amplitude_to_db(melspectrogram) |
|
        delta = torchaudio.functional.compute_deltas(melspectrogram)
        # Stack three spectrogram views to mimic the 3-channel image input
        # expected by the MobileNetV2 backbone.
        x = torch.stack([melspectrogram, db_melspectrogram, delta], dim=1)
|
return self.model(x) |
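
# Minimal smoke test for the wrapper (a sketch; it assumes a fine-tuned
# checkpoint already exists at config.MODEL_PATH):
#
#   clf = AudioClassifier(config.MODEL_NAME, config.MODEL_PATH, new=False)
#   dummy = torch.randn(2, int(4.5 * config.SAMPLE_RATE))  # two fake waveforms
#   print(clf(dummy).logits.shape)                          # torch.Size([2, 2])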
|
|
|
|
|
class Evaluator: |
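    """
    Run batched inference over a dataloader and collect per-sample predictions.
    """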
|
def __init__(self, model, dataloader, device): |
|
self.model = model |
|
self.dataloader = dataloader |
|
self.device = device |
|
|
|
@torch.no_grad() |
|
def evaluate(self): |
|
self.model.eval() |
|
all_predictions = [] |
|
all_labels = [] |
|
|
|
        for idx, (waveforms, labels) in enumerate(self.dataloader):
            waveforms = waveforms.to(self.device)
            outputs = self.model(waveforms).logits
            predictions = torch.argmax(outputs, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Periodically release cached GPU memory to keep peak usage low.
            if (idx + 1) % 10 == 0:
                torch.cuda.empty_cache()
|
|
|
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)

        # Ground-truth labels are collected above in case metric computation
        # (accuracy_score, f1_score, ...) is added here; only predictions are returned.
        return all_predictions
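
# End-to-end evaluation sketch (assumed wiring; `test_loader` and `raw` are the
# DataLoader and dataset split from the usage sketch above):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   clf = AudioClassifier(config.MODEL_NAME, config.MODEL_PATH, new=False).to(device)
#   preds = Evaluator(clf, test_loader, device).evaluate()
#   labels = np.array(raw["label"])
#   print(accuracy_score(labels, preds), f1_score(labels, preds))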
|
|