ariel-eddie committed
Commit 97dae9c · verified · 1 Parent(s): d763135
Files changed (1)
  1. audio_utils.py +179 -0
audio_utils.py ADDED
@@ -0,0 +1,179 @@
+"""
+This module implements building blocks for a deep learning audio-classification
+pipeline based on a pre-trained MobileNetV2 model: dataset loading and
+preprocessing, the classifier itself, and evaluation. The training,
+acceleration, and emissions-tracking imports (tqdm, accelerate, codecarbon)
+anticipate the rest of the pipeline, which is not defined in this file.
+"""
+
+import os
+import torch
+import torch.nn as nn
+import torchaudio
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from transformers import AutoModelForImageClassification
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+from tqdm import tqdm
+import logging
+from datasets import load_dataset
+from accelerate import Accelerator
+from codecarbon import EmissionsTracker
+import time
+
+
+class Config:
+    """
+    Configuration class to store hyperparameters and model settings.
+    """
+    SAMPLE_RATE = 16000
+    N_FFT = 800
+    N_MELS = 128
+    HOP_LENGTH = None
+    SIZE = (96, 96)
+    SCALING_DIM = (1, 2)
+    LEARNING_RATE = 0.0005
+    BATCH_SIZE = 32
+    NUM_WORKERS = 4
+    NUM_EPOCHS = 1
+    MODEL_NAME = "google/mobilenet_v2_0.35_96"
+    MODEL_PATH = "models-legacy/last/scaled_model_800_128_96x96_mobilenet_small_unscaled.pth"
+
+config = Config()
+
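
For orientation, a sanity-check sketch (not part of the commit) of what these
settings imply. It relies on torchaudio's defaults: with HOP_LENGTH = None,
MelSpectrogram uses win_length = n_fft and hop_length = win_length // 2 = 400.

    import torch
    import torchaudio

    # A 4.5 s clip at 16 kHz is 72000 samples; with hop_length 400 and
    # center=True this yields 72000 // 400 + 1 = 181 frames.
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=config.SAMPLE_RATE,  # 16000
        n_fft=config.N_FFT,              # 800
        n_mels=config.N_MELS,            # 128
        hop_length=config.HOP_LENGTH,    # None -> 400
    )
    spec = mel(torch.randn(72000))
    print(spec.shape)  # torch.Size([128, 181])
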
+class AudioDataset(Dataset):
+    """
+    Custom Dataset class for loading and processing audio data.
+
+    Args:
+        data (list): List of audio data samples.
+        sample_rate (int, optional): Target sample rate for audio resampling. Defaults to 16000.
+        audio_target_length (float, optional): Target length of audio in seconds. Defaults to 4.5.
+    """
+    def __init__(self, data, sample_rate=16000, audio_target_length=4.5):
+        self.data = data
+        self.sample_rate = sample_rate
+        self.audio_target_length = audio_target_length
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # 1. Lazily create a cache of resamplers, keyed by (orig_freq, target_freq)
+        if not hasattr(self, '_resampler_cache'):
+            self._resampler_cache = {}
+
+        # 2. Load the waveform, substituting near-silence for empty clips
+        data_item = self.data[index]
+        if len(data_item["audio"]["array"]) > 0:
+            waveform = torch.FloatTensor(data_item["audio"]["array"])
+        else:
+            waveform = torch.ones(36000) * 1e-5
+
+        # 3. Build (and cache) a resampler; Resample expects integer
+        #    frequencies, so cast the 4.5 s * 16 kHz product to int
+        orig_freq = waveform.shape[-1]
+        target_freq = int(self.audio_target_length * self.sample_rate)
+        resampler_key = (orig_freq, target_freq)
+
+        if resampler_key not in self._resampler_cache:
+            self._resampler_cache[resampler_key] = torchaudio.transforms.Resample(
+                orig_freq=orig_freq,
+                new_freq=target_freq
+            )
+
+        # 4. Apply resampling and return the waveform with its label
+        return self._resampler_cache[resampler_key](waveform), data_item["label"]
+
+
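
Note the design choice in step 3: instead of padding or truncating, each clip is
"resampled" from its own sample count to the target count, which time-stretches
every waveform to exactly 4.5 s worth of samples (at the cost of a pitch shift
for clips that were not 4.5 s long). A minimal usage sketch, assuming a Hugging
Face dataset with "audio" and "label" columns; the dataset id is a placeholder:

    from datasets import load_dataset

    raw = load_dataset("user/some-audio-dataset", split="test")  # placeholder id
    ds = AudioDataset(raw, sample_rate=config.SAMPLE_RATE, audio_target_length=4.5)

    waveform, label = ds[0]
    print(waveform.shape)  # torch.Size([72000]) -- every clip becomes 4.5 s * 16 kHz
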
+def collate_fn(batch):
+    """
+    Collate function to stack inputs and labels into batches.
+
+    Args:
+        batch (list): List of tuples containing inputs and labels.
+    Returns:
+        tuple: Stacked inputs and labels.
+    """
+    return torch.stack([inputs for inputs, _ in batch]), torch.tensor([label for _, label in batch])
+
+
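
Because the dataset forces every waveform to the same length, torch.stack in
collate_fn never sees ragged tensors. Wiring it into a DataLoader could look
like this (reusing the hypothetical ds from the sketch above):

    from torch.utils.data import DataLoader

    loader = DataLoader(
        ds,                              # AudioDataset instance
        batch_size=config.BATCH_SIZE,    # 32
        num_workers=config.NUM_WORKERS,  # 4
        collate_fn=collate_fn,
        shuffle=False,                   # order is irrelevant for evaluation
    )
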
+class AudioClassifier(nn.Module):
+    """
+    Audio classification model using a pre-trained MobileNetV2.
+
+    Args:
+        model_name (str): Name of the pre-trained model.
+        model_path (str): Path to save/load the model.
+        new (bool, optional): Whether to load a new model or an existing one. Defaults to True.
+    """
+    def __init__(self, model_name, model_path, new=True):
+        super().__init__()
+        self.model = self.load_model(model_name, model_path, new)
+        self.num_classes = 2
+        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+            sample_rate=config.SAMPLE_RATE,
+            n_fft=config.N_FFT,
+            n_mels=config.N_MELS,
+            hop_length=config.HOP_LENGTH
+        )
+        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+    def load_model(self, model_name, model_path, new=False):
+        """
+        Load the pre-trained model and modify the classifier.
+
+        Args:
+            model_name (str): Name of the pre-trained model.
+            model_path (str): Path to save/load the model.
+            new (bool, optional): Whether to load a new model or an existing one. Defaults to False.
+        Returns:
+            nn.Module: Loaded model.
+        """
+        model = AutoModelForImageClassification.from_pretrained(model_name)
+        model.classifier = torch.nn.Sequential(
+            nn.Linear(in_features=1280, out_features=2))
+
+        for param in model.parameters():
+            param.requires_grad = True
+        # Honor the `new` flag described in the docstring: only restore
+        # saved weights when loading an existing model.
+        if not new:
+            state_dict = torch.load(model_path)
+            model.load_state_dict(state_dict)
+        return model
+
+    def forward(self, waveforms):
+        """
+        Forward pass through the model.
+
+        Args:
+            waveforms (torch.Tensor): Input audio waveforms.
+        Returns:
+            torch.Tensor: Model output.
+        """
+        melspectrogram = self.mel_spectrogram(waveforms)
+        melspectrogram = nn.functional.interpolate(melspectrogram.unsqueeze(1),
+                                                   size=config.SIZE,
+                                                   mode="bilinear",
+                                                   align_corners=False).squeeze(1)
+        db_melspectrogram = self.amplitude_to_db(melspectrogram)
+        delta = torchaudio.functional.compute_deltas(melspectrogram)
+        x = torch.stack([melspectrogram, db_melspectrogram, delta], dim=1)
+        return self.model(x)
+
+
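
forward() turns a batch of raw waveforms into a 3-channel 96x96 "image" (linear
mel, dB mel, and deltas as the channels) so the ImageNet-style backbone can
consume it. A shape walk-through sketch, where new=True loads only the
pre-trained backbone without the saved checkpoint:

    import torch

    model = AudioClassifier(config.MODEL_NAME, config.MODEL_PATH, new=True)
    batch = torch.randn(2, 72000)  # two 4.5 s clips at 16 kHz
    out = model(batch)             # mel (2, 128, 181) -> resized (2, 96, 96)
                                   # -> stacked channels (2, 3, 96, 96)
    print(out.logits.shape)        # torch.Size([2, 2]) -- binary classification
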
+class Evaluator:
+    """
+    Runs batched inference over a dataloader and collects predictions.
+
+    Args:
+        model (nn.Module): Model to evaluate.
+        dataloader (DataLoader): Yields (waveforms, labels) batches.
+        device (torch.device or str): Device to run inference on.
+    """
+    def __init__(self, model, dataloader, device):
+        self.model = model
+        self.dataloader = dataloader
+        self.device = device
+
+    @torch.no_grad()
+    def evaluate(self):
+        self.model.eval()
+        all_predictions = []
+        all_labels = []
+
+        for idx, (waveforms, labels) in enumerate(self.dataloader, start=1):
+            waveforms = waveforms.to(self.device)
+            outputs = self.model(waveforms).logits
+            predictions = torch.argmax(outputs, dim=1)
+            all_predictions.extend(predictions.cpu().numpy())
+            all_labels.extend(labels.cpu().numpy())
+            # Periodically release cached GPU memory
+            if idx % 10 == 0:
+                torch.cuda.empty_cache()
+
+        all_predictions = np.array(all_predictions)
+        all_labels = np.array(all_labels)
+
+        # return self.compute_metrics(all_predictions, all_labels)
+        return all_predictions
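
Putting the pieces together, evaluation might look like the sketch below. Note
that evaluate() currently returns raw predictions only: compute_metrics appears
solely in the commented-out line and is not defined in this file, so metric
computation (presumably the purpose of the sklearn imports) falls to the caller.

    import torch
    from sklearn.metrics import accuracy_score

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AudioClassifier(config.MODEL_NAME, config.MODEL_PATH, new=False).to(device)

    evaluator = Evaluator(model, loader, device)  # loader from the DataLoader sketch
    predictions = evaluator.evaluate()            # np.ndarray of predicted class indices

    # Labels are collected inside evaluate() but not returned; recover them
    # from the dataset to keep this sketch self-contained.
    labels = [ds[i][1] for i in range(len(ds))]
    print(accuracy_score(labels, predictions))
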