Tharya committed
Commit 61fd02c · verified · 1 Parent(s): e713ac9

Upload gtzan_dataset_linear_probe.py

Files changed (1)
  1. gtzan_dataset_linear_probe.py +296 -0
gtzan_dataset_linear_probe.py ADDED
@@ -0,0 +1,296 @@
from typing import Union, Callable, List, Optional, Dict
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import numpy as np
import librosa
import miniaudio
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from functools import partial
import math

from mae import MaskedAutoencoderViT


def load_audio(
    path: str,
    sr: int = 32000,
    duration: int = 20,
) -> np.ndarray:
    g = miniaudio.stream_file(path, output_format=miniaudio.SampleFormat.FLOAT32, nchannels=1,
                              sample_rate=sr, frames_to_read=sr * duration)
    signal = np.array(next(g))
    return signal
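
# Note on load_audio: miniaudio decodes to mono float32 and resamples to `sr` while
# streaming; with the very large `duration` used below, the first chunk of the stream
# already covers an entire ~30 s GTZAN clip, so a single next() call returns the full signal.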


def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    mel_spec = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
        window='hann', pad_mode='constant'
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)


def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    return (arr - arr.mean()) / (arr.std() + eps)


device = 'cuda:0'
seed = 42
train_size = 0.8  # 80% train, 20% test
batch_size_train = 10
batch_size_test = 32
num_workers = 1
lr = 1e-3
epochs = 200
detection_epoch = 20  # early-stopping patience: stop after this many epochs without improvement

sr = 32000
n_fft = 800  # 25 ms
hop_length = 320  # 10 ms
duration = 10000  # seconds; 10000 ~= Inf, i.e. read the whole audio file

feature_length = 2048  # length of the mel spectrogram (the MAE is trained on 2048x128 mel spectrograms)
patch_size = 16  # the MAE splits the mel spectrogram into 16x16 patches

feature_padding = True
header = 'mean'

mlp_num_neurons = [768, 10]
mlp_activation_layer = nn.ReLU
mlp_bias = True

torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

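# Rough shape check (assuming the standard 30 s GTZAN clips): 30 s at 32 kHz with a
# 320-sample hop gives on the order of 3000 mel frames, which exceeds
# feature_length = 2048, so each clip is split into two snippets further below and
# their embeddings are averaged into a single 768-d vector per file.
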
# =============================== model ===============================
mae = MaskedAutoencoderViT(
    img_size=(2048, 128),
    patch_size=16,
    in_chans=1,
    embed_dim=768,
    depth=12,
    num_heads=12,
    decoder_mode=1,
    no_shift=False,
    decoder_embed_dim=512,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    norm_pix_loss=False,
    pos_trainable=False,
)

# Load pre-trained weights
ckpt_path = 'music-mae-32kHz.pth.pth'
mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
mae.to(device)
mae.eval()

# =============================== data ===============================
fp = Path('GTZAN-dataset/genres_original')
audio_data = dict()  # {genre: [audio_file1, audio_file2, ...]}

for d in fp.iterdir():
    if d.is_dir():
        for f in d.iterdir():
            if f.is_file():
                genre = f.name.split('.')[0]  # e.g. 'blues.00000.wav' -> 'blues'
                if genre not in audio_data:
                    audio_data[genre] = [str(f)]
                else:
                    audio_data[genre].append(str(f))

audio_data_train = dict()
audio_data_test = dict()

for k, v in audio_data.items():
    train_data, test_data = train_test_split(v, train_size=train_size, random_state=seed, shuffle=True)
    audio_data_train[k] = train_data
    audio_data_test[k] = test_data


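# Because train_test_split is applied to each genre's file list separately, the 80/20
# split is effectively stratified: every genre contributes the same train/test proportion.
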
@torch.no_grad()
def infer_mae_embedding(data: Dict) -> Dict:
    emb_data = dict()  # {genre: [embed1, embed2, ...]}

    for k, v in tqdm(data.items(), desc='infer mae embedding', total=len(data)):
        for f in v:
            try:
                mel_spec = mel_spectrogram(load_audio(f, duration=duration), sr=sr, n_fft=n_fft, hop_length=hop_length)
            except Exception as e:
                print(e)
                print(f)
                continue

            # pad the mel spectrogram to a multiple of patch_size
            input_length = mel_spec.shape[0]
            n = math.ceil(input_length / patch_size)
            if input_length < patch_size * n:
                pad_length = patch_size * n - input_length
                mel_spec = np.pad(mel_spec, ((0, pad_length), (0, 0)), mode='constant', constant_values=mel_spec.min())

            # if the mel spectrogram is longer than feature_length after padding,
            # split it into multiple snippets
            input_length = mel_spec.shape[0]
            embeds = []
            for i in range(0, input_length, feature_length):
                snippet = mel_spec[i:i + feature_length]
                snippet = normalize(snippet)
                snippet = snippet[None, None, :, :]
                x = torch.from_numpy(snippet).to(device)
                y = mae.forward_encoder_no_mask(x, header=header)  # (1, 768)
                y = y / y.norm(p=2, dim=-1, keepdim=True)  # L2-normalize
                y = y.cpu().numpy().squeeze()
                embeds.append(y)

            y = np.mean(embeds, axis=0)  # (768,)

            if k not in emb_data:
                emb_data[k] = [y]
            else:
                emb_data[k].append(y)

    return emb_data


audio_emb_train = infer_mae_embedding(audio_data_train)
audio_emb_test = infer_mae_embedding(audio_data_test)

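# Each audio file is now represented by a single 768-d embedding (the mean of its
# L2-normalized snippet embeddings), so the probe below trains on these cached
# features and never touches the frozen MAE again.
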
label_set = sorted(audio_emb_train.keys())  # sorted so the label -> index mapping is deterministic
label_map = {label: i for i, label in enumerate(label_set)}
print(label_map)


class MLP(torch.nn.Sequential):
    def __init__(
        self,
        num_neurons: List[int],
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        layers = []
        for c_in, c_out in zip(num_neurons[:-1], num_neurons[1:]):
            layers.append(torch.nn.Linear(c_in, c_out, bias=bias))
            layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout))

        # drop the trailing activation and dropout so the network ends with a Linear layer
        layers.pop()
        layers.pop()

        super().__init__(*layers)


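# With mlp_num_neurons = [768, 10] the loop above emits a single nn.Linear(768, 10)
# (the trailing activation and dropout are popped), so the classifier is a plain
# linear probe on top of the 768-d MAE embedding.
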
class SimpleDataset(Dataset):
    def __init__(self, dict_data: Dict, label_map: Dict):
        self.embed_with_label = []

        for k, v in dict_data.items():
            for emb in v:
                self.embed_with_label.append((emb, label_map[k]))

    def __len__(self):
        return len(self.embed_with_label)

    def __getitem__(self, idx):
        return self.embed_with_label[idx]


train_dataset = SimpleDataset(audio_emb_train, label_map)
test_dataset = SimpleDataset(audio_emb_test, label_map)
print(f"len(train_dataset): {len(train_dataset)}")
print(f"len(test_dataset): {len(test_dataset)}")


def train_one_epoch(model, device, dataloader, loss_fn, optimizer):
    model.train()

    # for batch in tqdm(dataloader, desc='train', total=len(dataloader)):
    for batch in dataloader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        y_logit = model(x)
        loss = loss_fn(y_logit, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


@torch.no_grad()
def eval_one_epoch(model, device, dataloader, loss_fn):
    model.eval()

    total_loss = 0.0
    total_correct = 0.0
    total_num = 0.0

    for batch in dataloader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        y_logit = model(x)
        loss = loss_fn(y_logit, y)

        total_loss += loss.item() * x.shape[0]
        total_correct += (y_logit.argmax(dim=-1) == y).sum().item()
        total_num += x.shape[0]

    loss = total_loss / total_num
    acc = total_correct / total_num

    return loss, acc


mlp = MLP(
    num_neurons=mlp_num_neurons,
    activation_layer=mlp_activation_layer,
    bias=mlp_bias,
    dropout=0.0
)
print(mlp)

mlp.to(device)

optimizer = Adam(mlp.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=num_workers)

test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)
print(f"init: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")

best_accuracy = 0.0
at = 0  # epoch at which the best test accuracy was seen

for epoch in range(epochs):
    train_one_epoch(mlp, device, train_dataloader, loss_fn, optimizer)
    test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)

    print(f"epoch {epoch}: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")

    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        at = epoch

    if epoch - at >= detection_epoch:
        print(f"early stop at epoch {epoch}")
        print(f"best accuracy: {best_accuracy:.4f} at epoch {at}")
        break
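
# Assumed usage (not stated in the script itself): run `python gtzan_dataset_linear_probe.py`
# from a directory containing GTZAN-dataset/genres_original/, the music-mae-32kHz.pth.pth
# checkpoint, and a local mae.py providing MaskedAutoencoderViT with a
# forward_encoder_no_mask(x, header=...) method; a CUDA device is required since
# device is hard-coded to 'cuda:0'.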