ldhldh committed on
Commit
2c0f55c
·
verified ·
1 Parent(s): bb5a96d

Upload 28 files

Browse files
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import logging
2
+
3
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
src/commons.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility file for src toolkit."""
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ WHISPER_MODEL_WEIGHTS_PATH = "src/models/assets/tiny_enc.en.pt"
9
+
10
+
11
def set_seed(seed: int):
    """Seed every random number generator the toolkit relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch, and exports
    ``PYTHONHASHSEED``. When CUDA is available the CUDA generators are seeded
    as well and cuDNN is switched to deterministic, non-benchmarking mode.

    Args:
        seed: Value applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # NOTE(review): the source's indentation was lost in extraction; the cuDNN
    # flags are assumed to sit inside the CUDA branch - confirm.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
src/datasets/KoAAD_dataset.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
2
+ import pandas as pd
3
+ from pathlib import Path
4
+ import os
5
+
6
class KoAAD(SimpleAudioFakeDataset):
    """KoAAD audio files, every one labelled ``spoof``.

    Each directory directly under ``root_path`` is scanned recursively for
    audio files and split 70/30: ``subset == 'train'`` keeps the first 70 %
    of a directory's files, any other value keeps the remaining 30 %.

    Args:
        root_path: Directory whose sub-directories hold the audio files.
        subset: ``'train'`` for the training share; anything else selects
            the held-out share.
        **kwargs: Forwarded to ``SimpleAudioFakeDataset.__init__``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        # The base constructor's positional parameters are (subset, transform);
        # passing root_path first pushed `subset` into the transform slot.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Collect the audio files of the requested subset.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        for folder in sorted(self.root_path.glob("*")):
            if not folder.is_dir():
                continue

            # Sorting makes the slice-based split reproducible: glob order
            # depends on the file system. The pattern matches .wav and .mp3
            # (and any other combination of the three character classes).
            samples_list = sorted(folder.rglob("*.[wm][ap][v3]"))
            boundary = int(len(samples_list) * 0.7)
            if self.subset == 'train':
                samples_list = samples_list[:boundary]
            else:
                samples_list = samples_list[boundary:]

            for sample in samples_list:
                # Skips entries that vanished or are dangling symlinks.
                if not sample.exists():
                    continue
                samples["user_id"].append(None)
                samples["path"].append(sample)
                samples["sample_name"].append(sample.stem)
                samples["attack_type"].append("-")
                samples["label"].append("spoof")

        print(f"KoAAD_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
45
+
src/datasets/MAILABS_dataset.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
2
+ import pandas as pd
3
+ from pathlib import Path
4
+ import os
5
+
6
class MAILABS(SimpleAudioFakeDataset):
    """M-AILABS ``en_US/by_book`` recordings, every one labelled ``bonafide``.

    Files are split 70/30 separately inside each second-level directory below
    ``en_US/by_book``.

    Args:
        root_path: Dataset root containing ``en_US/by_book``.
        subset: ``'train'`` keeps the first 70 % of each directory's files;
            ``'test'`` and ``'val'`` keep the remaining 30 %; any other value
            (including ``None``) keeps every file.
        **kwargs: Forwarded to ``SimpleAudioFakeDataset.__init__``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        # The base constructor's positional parameters are (subset, transform);
        # passing root_path first pushed `subset` into the transform slot.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Collect the ``.wav`` files of the requested subset.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }
        split = [0.7, 0.3]

        for group_dir in sorted(self.root_path.glob("en_US/by_book/*")):
            if not group_dir.is_dir():
                continue
            for speaker_dir in sorted(group_dir.glob("*")):
                # Sorted so the slice-based split does not depend on the
                # file system's listing order.
                samples_list = sorted(speaker_dir.rglob("*.wav"))
                boundary = int(len(samples_list) * split[0])
                if self.subset == 'train':
                    samples_list = samples_list[:boundary]
                elif self.subset in ('test', 'val'):
                    # 'val' previously fell through and received every file,
                    # training share included; it now matches the sibling
                    # datasets, where any non-train subset is held-out data.
                    samples_list = samples_list[boundary:]

                for sample in samples_list:
                    # Hidden files (names starting with a dot) are not audio
                    # samples of the corpus.
                    if sample.name.startswith("."):
                        continue
                    if not sample.exists():
                        continue
                    samples["user_id"].append(None)
                    samples["path"].append(sample)
                    samples["sample_name"].append(sample.stem)
                    samples["attack_type"].append("-")
                    samples["label"].append("bonafide")

        print(f"MAILABS_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
50
+
src/datasets/MLAADv3_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
class MLAADv3(SimpleAudioFakeDataset):
    """MLAAD v3 generated speech, every sample labelled ``spoof``.

    For each language code in ``languages`` the directories below
    ``<root_path>/fake/<lang>`` are scanned recursively for ``.wav`` files,
    and each directory is split 70/30 on its own.

    Args:
        root_path: Dataset root containing the ``fake`` directory.
        subset: ``'train'`` keeps the first 70 % of each directory's files;
            any other value keeps the remaining 30 %.
        **kwargs: Forwarded to ``SimpleAudioFakeDataset.__init__``.
    """

    languages = ['fr', 'et', 'ar', 'hu', 'bg', 'es', 'el', 'da', 'ga', 'ru', 'fi',
                 'uk', 'pl', 'en', 'sw', 'mt', 'sk', 'ro', 'hi', 'cs', 'nl', 'it', 'de']

    def __init__(self, root_path, subset=None, **kwargs):
        # The base constructor's positional parameters are (subset, transform);
        # passing root_path first pushed `subset` into the transform slot.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Collect the ``.wav`` files of the requested subset.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``language``,
            ``sample_name``, ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "language": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        for lang in self.languages:
            lang_dir = self.root_path / f"fake/{lang}"
            # Entries returned by glob already exist, so the former
            # per-folder existence check could never fire and was dropped.
            for folder in sorted(lang_dir.glob("*")):
                # Sorted so the slice-based split does not depend on the
                # file system's listing order.
                samples_list = sorted(folder.rglob("*.wav"))
                boundary = int(len(samples_list) * 0.7)
                if self.subset == 'train':
                    samples_list = samples_list[:boundary]
                else:
                    samples_list = samples_list[boundary:]

                for sample in samples_list:
                    samples["user_id"].append(None)
                    samples["language"].append(lang)
                    samples["path"].append(sample)
                    samples["sample_name"].append(sample.stem)
                    samples["attack_type"].append("-")
                    samples["label"].append("spoof")

        print(f"__MLAADv3_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
src/datasets/__init__.py ADDED
File without changes
src/datasets/aihub_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
class AIHUB(SimpleAudioFakeDataset):
    """AIHUB recordings, every sample labelled ``bonafide``.

    All ``.wav`` files below ``root_path`` are gathered recursively and split
    70/30 as one list.

    Args:
        root_path: Directory holding the recordings.
        subset: ``'train'`` keeps the first 70 % of the files; any other
            value keeps the remaining 30 %.
        **kwargs: Forwarded to ``SimpleAudioFakeDataset.__init__``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        # The base constructor's positional parameters are (subset, transform);
        # passing root_path first pushed `subset` into the transform slot.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Collect the ``.wav`` files of the requested subset.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        path = self.root_path

        # Check that the dataset directory exists; a missing directory is
        # only reported, the scan below then simply finds no files.
        if not path.exists():
            print(f"{path} 경로를 찾을 수 없습니다.")

        # Sorted so the slice-based split does not depend on the file
        # system's listing order.
        samples_list = sorted(path.rglob("*.wav"))
        boundary = int(len(samples_list) * 0.7)
        if self.subset == 'train':
            samples_list = samples_list[:boundary]
        else:
            samples_list = samples_list[boundary:]

        for sample in samples_list:
            samples["user_id"].append(None)
            samples["path"].append(sample)
            samples["sample_name"].append(sample.stem)
            samples["attack_type"].append("-")
            samples["label"].append("bonafide")

        print(f"__AIHUB_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
40
+
src/datasets/asvspoof_dataset.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ if __name__ == "__main__":
5
+ import sys
6
+ sys.path.append(str(Path(__file__).parent.parent.parent.absolute()))
7
+
8
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
9
+
10
+ ASVSPOOF_SPLIT = {
11
+ "train": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
12
+ "test": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
13
+ "val": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
14
+ "partition_ratio": [0.7, 0.15],
15
+ "seed": 45,
16
+ }
17
+
18
+
19
class ASVSpoofDataset(SimpleAudioFakeDataset):
    """ASVspoof 2019 LA data re-partitioned with the toolkit's own split.

    The protocol of every official part in ``subsets`` is read, the bona fide
    and allowed spoofed lines are each split with ``split_samples``, and the
    share belonging to the requested subset is kept.

    Args:
        path: Dataset root holding the protocol folder and the
            ``ASVspoof2019_LA_<part>`` directories.
        subset: Key of ``ASVSPOOF_SPLIT`` naming the list of allowed attacks;
            also selects the share returned by ``split_samples``.
        transform: Stored on the instance.
    """

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.allowed_attacks = ASVSPOOF_SPLIT[subset]
        self.partition_ratio = ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = ASVSPOOF_SPLIT["seed"]

        # Gather one frame per official part and concatenate once instead of
        # growing a DataFrame inside the loop.
        frames = []
        for part in self.subsets:
            part_dir = Path(self.path) / f"{self.subset_dir_prefix}{part}"
            frames.append(self.read_protocol(part_dir, self.get_protocol_path(part)))
        self.samples = pd.concat(frames) if frames else pd.DataFrame()

        self.transform = transform

    def get_protocol_path(self, subset):
        """Return the protocol ``.txt`` file whose stem contains ``subset``.

        Raises:
            FileNotFoundError: No protocol file matches. Previously ``None``
                was returned and the failure only surfaced later in ``open``.
        """
        protocol_dir = Path(self.path) / self.protocol_folder_name
        # Sorted so the choice is deterministic if several files match.
        for candidate in sorted(protocol_dir.glob("*.txt")):
            if subset in candidate.stem:
                return candidate
        raise FileNotFoundError(
            f"No protocol file for subset '{subset}' in {protocol_dir}"
        )

    def read_protocol(self, subset_dir, protocol_path):
        """Parse one protocol file and keep this dataset's share of it.

        Lines whose fourth field is ``-`` are bona fide, lines whose fourth
        field is in ``allowed_attacks`` are spoofed, all others are ignored.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []
        with open(protocol_path, "r") as file:
            for line in file:
                if not line.strip():
                    # A blank line has no fourth field to inspect.
                    continue
                attack_type = line.strip().split(" ")[3]

                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)

        # Spoofed lines first, then bona fide ones.
        for line in self.split_samples(fake_samples):
            samples = self.add_line_to_samples(samples, line, subset_dir)

        for line in self.split_samples(real_samples):
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)

    @staticmethod
    def add_line_to_samples(samples, line, subset_dir):
        """Append the sample described by one protocol line to ``samples``.

        ``line`` must hold exactly five space-separated fields; the audio file
        is expected at ``<subset_dir>/flac/<sample_name>.flac``.

        Returns:
            The same ``samples`` dict, mutated in place.
        """
        user_id, sample_name, _, attack_type, label = line.strip().split(" ")
        flac_path = subset_dir / "flac" / f"{sample_name}.flac"

        samples["user_id"].append(user_id)
        samples["sample_name"].append(sample_name)
        samples["attack_type"].append(attack_type)
        samples["label"].append(label)

        assert flac_path.exists(), f"Missing audio file: {flac_path}"
        samples["path"].append(flac_path)

        return samples
95
+
96
class ASVSpoof2019DatasetOriginal(ASVSpoofDataset):
    """ASVspoof 2019 LA data using the official train/dev/eval partition.

    ``fold_subset`` is translated through ``subsets`` to one official part
    and that part's protocol is loaded completely, without re-splitting.
    """

    # toolkit subset name -> official ASVspoof part
    subsets = {"train": "train", "test": "dev", "val": "eval"}

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    # Attack identifiers expected in each official part's protocol.
    subset_dirs_attacks = {
        "train": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "dev": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "eval": [
            "A07", "A08", "A09", "A10", "A11", "A12", "A13", "A14", "A15",
            "A16", "A17", "A18", "A19"
        ]
    }

    def __init__(self, path, fold_subset="train"):
        """
        Initialise object. Skip __init__ of ASVSpoofDataset due to different
        logic, but follow SimpleAudioFakeDataset constructor.

        Args:
            path: Dataset root holding the protocol folder and the
                ``ASVspoof2019_LA_<part>`` directories.
            fold_subset: One of the keys of ``subsets``.
        """
        # SimpleAudioFakeDataset.__init__ takes (subset, transform, ...). The
        # previous call passed (float('inf'), fold_subset), which stored inf
        # as the subset and the subset name as the transform.
        super(ASVSpoofDataset, self).__init__(fold_subset)
        self.path = path
        subset = self.subsets[fold_subset]
        self.allowed_attacks = self.subset_dirs_attacks[subset]
        subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
        subset_protocol_path = self.get_protocol_path(subset)
        self.samples = self.read_protocol(subset_dir, subset_protocol_path)

    def read_protocol(self, subset_dir, protocol_path):
        """Load every line of one official protocol file.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``; spoofed rows come first.

        Raises:
            ValueError: A line names an attack that is not listed for this
                part in ``subset_dirs_attacks``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []

        with open(protocol_path, "r") as file:
            for line in file:
                if not line.strip():
                    # A blank line has no fourth field to inspect.
                    continue
                attack_type = line.strip().split(" ")[3]
                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)
                else:
                    raise ValueError(
                        "Tried to load attack that shouldn't be here!"
                    )

        for line in fake_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)
        for line in real_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)
155
+
src/datasets/base_dataset.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base dataset classes."""
2
+ import logging
3
+ import math
4
+ import random
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.dataset import T_co
12
+
13
+
14
+ LOGGER = logging.getLogger(__name__)
15
+
16
+ SAMPLING_RATE = 16_000
17
+ APPLY_NORMALIZATION = True
18
+ APPLY_TRIMMING = True
19
+ APPLY_PADDING = True
20
+ FRAMES_NUMBER = 480_000 # <- originally 64_600
21
+
22
+
23
+ SOX_SILENCE = [
24
+ # trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file)
25
+ # from beginning and middle/end
26
+ ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
27
+ ]
28
+
29
+
30
class SimpleAudioFakeDataset(Dataset):
    """Base map-style dataset of audio files labelled bona fide or spoofed.

    Subclasses fill ``self.samples`` with a DataFrame (or, after
    ``df2tuples``, a list of tuples) providing ``path``, ``label`` and
    ``attack_type`` for every item.

    Args:
        subset: Name of the partition this instance represents; ``'train'``,
            ``'test'`` or ``'val'`` when ``split_samples`` is used.
        transform: Stored on the instance; not applied in ``__getitem__``.
        return_label: Append the integer label to each returned item.
        return_meta: Append a metadata tuple to each returned item.
    """

    def __init__(
        self,
        subset,
        transform=None,
        return_label: bool = True,
        return_meta: bool = False,
    ):
        self.transform = transform
        self.samples = pd.DataFrame()

        self.subset = subset
        # Filled in by subclasses that partition their data themselves.
        self.allowed_attacks = None
        self.partition_ratio = None
        self.seed = None
        self.return_label = return_label
        self.return_meta = return_meta

    def split_samples(self, samples_list):
        """Shuffle ``samples_list`` deterministically and return this subset's share.

        The input is sorted, shuffled with ``self.seed`` and cut at the
        fractions in ``self.partition_ratio`` (train share, test share); what
        remains is the validation share.

        Args:
            samples_list: A DataFrame or an iterable of sortable items.

        Returns:
            A DataFrame slice for DataFrame input, otherwise a NumPy array.

        Raises:
            KeyError: ``self.subset`` is not ``'train'``, ``'test'`` or ``'val'``.
        """
        is_frame = isinstance(samples_list, pd.DataFrame)
        if is_frame:
            samples_list = samples_list.sort_values(by=list(samples_list.columns))
            samples_list = samples_list.sample(frac=1, random_state=self.seed)
        else:
            samples_list = sorted(samples_list)
            # A private generator yields the same permutation as seeding the
            # module-level one, without resetting global random state.
            random.Random(self.seed).shuffle(samples_list)

        p, s = self.partition_ratio
        total = len(samples_list)
        train_end = int(p * total)
        test_end = int((p + s) * total)
        bounds = {
            "train": (0, train_end),
            "test": (train_end, test_end),
            "val": (test_end, total),
        }
        start, stop = bounds[self.subset]

        if is_frame:
            # Positional slicing replaces np.split, which depends on the
            # deprecated DataFrame.swapaxes.
            return samples_list.iloc[start:stop]
        return np.asarray(samples_list)[start:stop]

    def df2tuples(self):
        """Replace ``self.samples`` by a list of ``(path, label, attack_type)`` tuples.

        Returns:
            The new list, which is also stored in ``self.samples``.
        """
        self.samples = [
            (str(row["path"]), row["label"], row["attack_type"])
            for _, row in self.samples.iterrows()
        ]
        return self.samples

    def __getitem__(self, index) -> list:
        """Load and preprocess the audio of item ``index``.

        Returns:
            ``[waveform, sample_rate]``, followed by the label (1 for
            ``"bonafide"``, 0 otherwise) when ``return_label`` is set, and by
            ``(attack_type, path, subset, duration_in_seconds)`` when
            ``return_meta`` is set. The duration is measured before
            preprocessing.
        """
        if isinstance(self.samples, pd.DataFrame):
            sample = self.samples.iloc[index]

            path = str(sample["path"])
            label = sample["label"]
            attack_type = sample["attack_type"]
            # Missing values (NaN after concatenating frames, or None) get a
            # placeholder; pd.isna also accepts None, unlike math.isnan.
            if not isinstance(attack_type, str) and pd.isna(attack_type):
                attack_type = "N/A"
        else:
            path, label, attack_type = self.samples[index]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        real_sec_length = len(waveform[0]) / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_label:
            label = 1 if label == "bonafide" else 0
            return_data.append(label)

        if self.return_meta:
            return_data.append(
                (
                    attack_type,
                    path,
                    self.subset,
                    real_sec_length,
                )
            )
        return return_data

    def __len__(self):
        """Return the number of items in ``self.samples``."""
        return len(self.samples)
108
+
109
+
110
def apply_preprocessing(
    waveform,
    sample_rate,
):
    """Run the module's preprocessing chain on a loaded waveform.

    Steps, in order: resample to ``SAMPLING_RATE`` (skipped when that
    constant is -1 or already matches), keep only the first channel, remove
    silence when ``APPLY_TRIMMING`` is set, and crop or loop the signal to
    ``FRAMES_NUMBER`` samples when ``APPLY_PADDING`` is set.

    Returns:
        Tuple ``(waveform, sample_rate)`` after preprocessing.
    """
    wants_resampling = SAMPLING_RATE != -1 and sample_rate != SAMPLING_RATE
    if wants_resampling:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Multi-channel input: only the first channel is kept (no down-mix).
    has_several_channels = waveform.dim() > 1 and waveform.shape[0] > 1
    if has_several_channels:
        waveform = waveform[:1, ...]

    # Silence removal via the sox effect chain.
    if APPLY_TRIMMING:
        waveform, sample_rate = apply_trim(waveform, sample_rate)

    # Fixed-length output: crop long signals, loop short ones.
    if APPLY_PADDING:
        waveform = apply_pad(waveform, FRAMES_NUMBER)

    return waveform, sample_rate
130
+
131
+
132
def resample_wave(waveform, sample_rate, target_sample_rate):
    """Resample a waveform tensor with the sox ``rate`` effect.

    Returns:
        Tuple ``(waveform, sample_rate)`` as returned by torchaudio.
    """
    rate_effect = [["rate", f"{target_sample_rate}"]]
    resampled, new_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, rate_effect
    )
    return resampled, new_rate
137
+
138
+
139
def resample_file(path, target_sample_rate, normalize=True):
    """Load an audio file and resample it with the sox ``rate`` effect.

    Args:
        path: Audio file to read.
        target_sample_rate: Rate handed to the ``rate`` effect.
        normalize: Passed through to torchaudio's ``apply_effects_file``.

    Returns:
        Tuple ``(waveform, sample_rate)`` as returned by torchaudio.
    """
    rate_effect = [["rate", f"{target_sample_rate}"]]
    resampled, new_rate = torchaudio.sox_effects.apply_effects_file(
        path, rate_effect, normalize=normalize
    )
    return resampled, new_rate
145
+
146
+
147
def apply_trim(waveform, sample_rate):
    """Strip silence from a waveform using the ``SOX_SILENCE`` effect chain.

    If trimming leaves no samples at all, the untouched input is returned
    instead of an empty signal.

    Returns:
        Tuple ``(waveform, sample_rate)``.
    """
    trimmed, trimmed_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, SOX_SILENCE
    )

    # size()[1] is the sample count of the (channels, samples) result.
    if trimmed.size()[1] > 0:
        return trimmed, trimmed_rate

    return waveform, sample_rate
158
+
159
+
160
def apply_pad(waveform, cut):
    """Return a 1-D signal of exactly `cut` samples.

    A leading singleton channel dimension is dropped. Signals of at least
    `cut` samples are cropped; shorter ones are repeated back to back and
    then cropped.
    """
    signal = waveform.squeeze(0)
    length = signal.shape[0]

    if cut <= length:
        return signal[:cut]

    # One repetition more than the integer quotient always covers `cut`.
    repeats = int(cut / length) + 1
    looped = torch.tile(signal, (1, repeats))
    return looped[0, :cut]
src/datasets/deepfake_asvspoof_dataset.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
7
+
8
+ DF_ASVSPOOF_SPLIT = {
9
+ "partition_ratio": [0.7, 0.15],
10
+ "seed": 45
11
+ }
12
+
13
+ LOGGER = logging.getLogger()
14
+
15
class DeepFakeASVSpoofDataset(SimpleAudioFakeDataset):
    """ASVspoof 2021 DF evaluation data, partitioned with ``split_samples``.

    Args:
        path: Dataset root holding ``keys/CM/trial_metadata.txt`` and the
            ``ASVspoof2021_DF_eval_partXX`` directories.
        subset: Share returned by ``split_samples`` (``'train'``, ``'test'``
            or ``'val'``).
        transform: Stored on the instance.
    """

    protocol_file_name = "keys/CM/trial_metadata.txt"
    subset_dir_prefix = "ASVspoof2021_DF_eval"
    subset_parts = ("part00", "part01", "part02", "part03")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.partition_ratio = DF_ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = DF_ASVSPOOF_SPLIT["seed"]

        # The file index must exist before the protocol is read, because
        # add_line_to_samples looks every sample up in it.
        self.flac_paths = self.get_file_references()
        self.samples = self.read_protocol()

        self.transform = transform
        LOGGER.info(f"Spoof: {len(self.samples[self.samples['label'] == 'spoof'])}")
        LOGGER.info(f"Original: {len(self.samples[self.samples['label'] == 'bonafide'])}")

    def get_file_references(self):
        """Index every ``.flac`` file of all dataset parts by its stem.

        Returns:
            Dict mapping sample name to the file's ``Path``.
        """
        flac_paths = {}
        for part in self.subset_parts:
            part_dir = (
                Path(self.path)
                / f"{self.subset_dir_prefix}_{part}"
                / self.subset_dir_prefix
                / "flac"
            )
            for flac_file in part_dir.glob("*.flac"):
                flac_paths[flac_file.stem] = flac_file

        return flac_paths

    def read_protocol(self):
        """Parse the trial metadata and keep this subset's share of it.

        The sixth field of each line is the label; lines labelled neither
        ``bonafide`` nor ``spoof`` are ignored.

        Returns:
            pandas.DataFrame with the columns ``sample_name``, ``label``,
            ``path`` and ``attack_type``; spoofed rows come first.
        """
        samples = {
            "sample_name": [],
            "label": [],
            "path": [],
            "attack_type": [],
        }

        real_samples = []
        fake_samples = []
        with open(Path(self.path) / self.protocol_file_name, "r") as file:
            for line in file:
                if not line.strip():
                    # A blank line has no label field to inspect.
                    continue
                label = line.strip().split(" ")[5]

                if label == "bonafide":
                    real_samples.append(line)
                elif label == "spoof":
                    fake_samples.append(line)

        for line in self.split_samples(fake_samples):
            samples = self.add_line_to_samples(samples, line)

        for line in self.split_samples(real_samples):
            samples = self.add_line_to_samples(samples, line)

        return pd.DataFrame(samples)

    def add_line_to_samples(self, samples, line):
        """Append the sample described by one metadata line to ``samples``.

        Only the second field (sample name) and sixth field (label) are
        used, so lines with more trailing columns than the original
        eight-column layout are accepted too.

        Returns:
            The same ``samples`` dict, mutated in place.

        Raises:
            KeyError: No ``.flac`` file was indexed for the sample name.
        """
        fields = line.strip().split(" ")
        sample_name, label = fields[1], fields[5]

        samples["sample_name"].append(sample_name)
        samples["label"].append(label)
        # The label doubles as attack type for this dataset.
        samples["attack_type"].append(label)

        sample_path = self.flac_paths[sample_name]
        assert sample_path.exists()
        samples["path"].append(sample_path)

        return samples
86
+
src/datasets/detection_dataset.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+
5
+ import pandas as pd
6
+
7
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
8
+ from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
9
+ from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
10
+ from src.datasets.wavefake_dataset import WaveFakeDataset
11
+ from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal
12
+ from src.datasets.MLAADv3_dataset import MLAADv3
13
+ from src.datasets.MAILABS_dataset import MAILABS
14
+ from src.datasets.aihub_dataset import AIHUB
15
+ from src.datasets.KoAAD_dataset import KoAAD
16
+
17
+
18
+ LOGGER = logging.getLogger()
19
+
20
+
21
class DetectionDataset(SimpleAudioFakeDataset):
    """Union of the individual datasets whose root path is provided.

    One dataset object is built for every path argument that is not
    ``None``; their ``samples`` frames are concatenated and optionally
    balanced between ``bonafide`` and ``spoof``.

    Args:
        asvspoof_path: Root of the ASVspoof 2021 DF data.
        wavefake_path: Root of the WaveFake data.
        fakeavceleb_path: Root of the FakeAVCeleb data.
        asvspoof2019_path: Root of the ASVspoof 2019 LA data.
        MLAADv3_path: Root of the MLAAD v3 data.
        MAILABS_path: Root of the M-AILABS data.
        AIHUB_path: Root of the AIHUB data.
        KoAAD_path: Root of the KoAAD data.
        subset: Subset name handed to every dataset.
        transform: Stored on the instance.
        oversample: Duplicate random bona fide rows until both classes have
            the same size. Takes precedence over ``undersample``.
        undersample: Draw as many spoofed rows as there are bona fide ones.
        return_label: See ``SimpleAudioFakeDataset``.
        reduced_number: If truthy, keep a random sample of at most this many
            rows (fixed ``random_state``).
        return_meta: See ``SimpleAudioFakeDataset``.
    """

    def __init__(
        self,
        asvspoof_path=None,
        wavefake_path=None,
        fakeavceleb_path=None,
        asvspoof2019_path=None,
        MLAADv3_path=None,
        MAILABS_path=None,
        AIHUB_path=None,
        KoAAD_path=None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )
        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info(f"Using reduced number of samples - {reduced_number}!")
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        # These four used `Optional[str]` as the default VALUE, so omitting
        # one would have passed the typing object on as a dataset path.
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
    ) -> List[SimpleAudioFakeDataset]:
        """Build one dataset object per path that is not ``None``.

        Returns:
            The datasets in the fixed order ASVspoof 2021 DF, WaveFake,
            FakeAVCeleb, ASVspoof 2019, MLAADv3, MAILABS, AIHUB, KoAAD.
        """
        datasets = []

        if asvspoof_path is not None:
            datasets.append(DeepFakeASVSpoofDataset(asvspoof_path, subset=subset))

        if wavefake_path is not None:
            datasets.append(WaveFakeDataset(wavefake_path, subset=subset))

        if fakeavceleb_path is not None:
            datasets.append(FakeAVCelebDataset(fakeavceleb_path, subset=subset))

        if asvspoof2019_path is not None:
            datasets.append(
                ASVSpoof2019DatasetOriginal(asvspoof2019_path, fold_subset=subset)
            )

        if MLAADv3_path is not None:
            datasets.append(MLAADv3(MLAADv3_path, subset=subset))

        if MAILABS_path is not None:
            datasets.append(MAILABS(MAILABS_path, subset=subset))

        if AIHUB_path is not None:
            datasets.append(AIHUB(AIHUB_path, subset=subset))

        if KoAAD_path is not None:
            datasets.append(KoAAD(KoAAD_path, subset=subset))

        return datasets

    def _select_label(self, label):
        """Return the rows of ``self.samples`` whose label equals ``label``.

        Boolean masking replaces ``groupby(by=["label"])`` with scalar group
        keys, which pandas has deprecated.

        Raises:
            KeyError: No row carries ``label``.
        """
        rows = self.samples[self.samples["label"] == label]
        if rows.empty:
            raise KeyError(label)
        return rows

    def oversample_dataset(self):
        """Add randomly repeated bona fide rows until both classes are equal.

        Raises:
            NotImplementedError: There are more bona fide than spoofed rows.
            KeyError: One of the two classes has no rows.
        """
        bonafide = self._select_label("bonafide")
        spoof = self._select_label("spoof")

        diff_length = len(spoof) - len(bonafide)

        if diff_length < 0:
            raise NotImplementedError

        if diff_length > 0:
            extra = bonafide.sample(diff_length, replace=True)
            self.samples = pd.concat([self.samples, extra], ignore_index=True)

    def undersample_dataset(self):
        """Reduce the spoofed class to the size of the bona fide class.

        Raises:
            NotImplementedError: There are more bona fide than spoofed rows.
            KeyError: One of the two classes has no rows.
        """
        bonafide = self._select_label("bonafide")
        spoof = self._select_label("spoof")

        if len(spoof) < len(bonafide):
            raise NotImplementedError

        if len(spoof) > len(bonafide):
            # NOTE(review): drawn WITH replacement, so a spoofed row can
            # appear more than once - confirm this is intended.
            spoofs = spoof.sample(len(bonafide), replace=True)
            self.samples = pd.concat([bonafide, spoofs], ignore_index=True)

    def get_bonafide_only(self):
        """Restrict ``self.samples`` to bona fide rows and return them."""
        self.samples = self._select_label("bonafide")
        return self.samples

    def get_spoof_only(self):
        """Restrict ``self.samples`` to spoofed rows and return them."""
        self.samples = self._select_label("spoof")
        return self.samples
src/datasets/fakeavceleb_dataset.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
+ FAKEAVCELEB_SPLIT = {
8
+ "train": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
9
+ "test": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
10
+ "val": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
11
+ "partition_ratio": [0.7, 0.15],
12
+ "seed": 45
13
+ }
14
+
15
+
16
class FakeAVCelebDataset(SimpleAudioFakeDataset):
    """Audio track of FakeAVCeleb, partitioned with ``split_samples``.

    Args:
        path: Directory containing the ``FakeAVCeleb-audio`` folder.
        subset: Key of ``FAKEAVCELEB_SPLIT`` naming the allowed attacks;
            also selects the share returned by ``split_samples``.
        transform: Stored on the instance.
    """

    audio_folder = "FakeAVCeleb-audio"
    audio_extension = ".mp3"
    metadata_file = Path(audio_folder) / "meta_data.csv"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.allowed_attacks = FAKEAVCELEB_SPLIT[subset]
        self.partition_ratio = FAKEAVCELEB_SPLIT["partition_ratio"]
        self.seed = FAKEAVCELEB_SPLIT["seed"]

        self.metadata = self.get_metadata()

        self.samples = pd.concat(
            [self.get_fake_samples(), self.get_real_samples()], ignore_index=True
        )

    def get_metadata(self):
        """Read ``meta_data.csv`` and add an ``audio_type`` column.

        ``audio_type`` is the part of the ``type`` column after its last
        hyphen.
        """
        md = pd.read_csv(Path(self.path) / self.metadata_file)
        md["audio_type"] = md["type"].apply(lambda x: x.split("-")[-1])
        return md

    def get_fake_samples(self):
        """Collect this subset's spoofed samples, attack by attack.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        for attack_name in self.allowed_attacks:
            fake_samples = self.metadata[
                (self.metadata["method"] == attack_name)
                & (self.metadata["audio_type"] == "FakeAudio")
            ]
            if fake_samples.empty:
                continue

            # Split the row labels rather than (label, row) tuples. The seeded
            # shuffle depends only on the sorted order, so the same rows are
            # selected, but np.split no longer has to build an array from
            # ragged (int, Series) pairs, which current NumPy rejects.
            kept_labels = self.split_samples(list(fake_samples.index))

            for _, sample in fake_samples.loc[kept_labels].iterrows():
                samples["user_id"].append(sample["source"])
                samples["sample_name"].append(Path(sample["filename"]).stem)
                samples["attack_type"].append(sample["method"])
                samples["label"].append("spoof")
                samples["path"].append(self.get_file_path(sample))

        return pd.DataFrame(samples)

    def get_real_samples(self):
        """Collect this subset's bona fide samples.

        Returns:
            pandas.DataFrame with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        samples_list = self.metadata[
            (self.metadata["method"] == "real")
            & (self.metadata["audio_type"] == "RealAudio")
        ]

        samples_list = self.split_samples(samples_list)

        for _, sample in samples_list.iterrows():
            samples["user_id"].append(sample["source"])
            samples["sample_name"].append(Path(sample["filename"]).stem)
            samples["attack_type"].append("-")
            samples["label"].append("bonafide")
            samples["path"].append(self.get_file_path(sample))

        return pd.DataFrame(samples)

    def get_file_path(self, sample):
        """Build the audio file path for one metadata row.

        The first component of the row's ``path`` is replaced by
        ``audio_folder`` and the file name gets ``audio_extension``.
        """
        relative_dir = "/".join([self.audio_folder, *sample["path"].split("/")[1:]])
        file_name = Path(sample["filename"]).with_suffix(self.audio_extension)
        return Path(self.path) / relative_dir / file_name
94
+
src/datasets/in_the_wild_dataset.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
+
8
class InTheWildDataset(SimpleAudioFakeDataset):
    """The "In-the-Wild" dataset described by a ``meta.csv`` file.

    The constructor loads every row of the metadata; it does not partition
    the data. ``split_samples_per_speaker`` is available for callers that
    want a speaker-disjoint partition.

    Args:
        path: Directory containing ``meta.csv`` and the audio files.
        subset: Subset name used by the split helpers.
        transform: Stored on the instance.
        seed: Seed used by the split helpers.
        partition_ratio: ``(train share, test share)`` used by the split
            helpers.
        split_strategy: Stored as ``self.split_strategy``; the constructor
            does not act on it.
    """

    def __init__(
        self,
        path,
        subset="train",
        transform=None,
        seed=None,
        partition_ratio=(0.7, 0.15),
        split_strategy="random"
    ):
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.read_samples()
        self.partition_ratio = partition_ratio
        self.seed = seed
        # Kept so the requested strategy is not silently discarded.
        self.split_strategy = split_strategy

    def read_samples(self):
        """Load ``meta.csv`` into ``self.samples`` using the toolkit's schema.

        Adds ``path`` and ``attack_type``, maps the label ``bona-fide`` to
        ``bonafide`` and renames ``file``/``speaker`` to
        ``sample_name``/``user_id``.
        """
        path = Path(self.path)
        meta_path = path / "meta.csv"

        self.samples = pd.read_csv(meta_path)
        self.samples["path"] = self.samples["file"].apply(lambda n: str(path / n))
        self.samples["file"] = self.samples["file"].apply(lambda n: Path(n).stem)
        self.samples["label"] = self.samples["label"].map(
            {"bona-fide": "bonafide", "spoof": "spoof"}
        )
        self.samples["attack_type"] = self.samples["label"].map(
            {"bonafide": "-", "spoof": "X"}
        )
        self.samples.rename(
            columns={'file': 'sample_name', 'speaker': 'user_id'}, inplace=True
        )

    def split_samples_per_speaker(self, samples):
        """Return the rows of ``samples`` whose speaker belongs to this subset.

        The unique speakers are sorted, shuffled with ``self.seed`` and cut at
        the fractions in ``self.partition_ratio``, so no speaker appears in
        two subsets.

        Args:
            samples: DataFrame with a ``user_id`` column.
        """
        speaker_list = pd.Series(samples["user_id"].unique())
        speaker_list = speaker_list.sort_values()
        speaker_list = speaker_list.sample(frac=1, random_state=self.seed)
        speaker_list = list(speaker_list)

        p, s = self.partition_ratio
        total = len(speaker_list)
        subsets = np.split(speaker_list, [int(p * total), int((p + s) * total)])
        speaker_subset = dict(zip(['train', 'test', 'val'], subsets))[self.subset]
        # Filter the frame that was passed in; the previous version filtered
        # self.samples regardless of the argument.
        return samples[samples["user_id"].isin(speaker_subset)]
48
+
49
+
50
if __name__ == "__main__":
    # Manual smoke test: load the dataset from a relative path and print its
    # size, its speakers and the first item.
    dataset = InTheWildDataset(
        path="../datasets/release_in_the_wild",
        subset="val",
        seed=242,
        split_strategy="per_speaker",
    )

    print(len(dataset))
    speakers = dataset.samples["user_id"].unique()
    print(len(speakers))
    print(speakers)

    print(dataset[0])
src/datasets/wavefake_dataset.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
# Folder suffixes (see ``get_attack_from_path``) allowed in every subset.
_WAVEFAKE_ATTACKS = (
    'multi_band_melgan',
    'melgan_large',
    'parallel_wavegan',
    'waveglow',
    'full_band_melgan',
    'melgan',
    'hifiGAN',
)

# Split configuration: one independent list of allowed attacks per subset,
# plus the partition ratio and seed copied onto each dataset instance.
WAVEFAKE_SPLIT = {
    **{subset_name: list(_WAVEFAKE_ATTACKS) for subset_name in ("train", "test", "val")},
    "partition_ratio": [0.7, 0.15],
    "seed": 45,
}
14
+
15
+
16
class WaveFakeDataset(SimpleAudioFakeDataset):
    """Dataset combining generated (spoof) and real (bonafide) ``.wav`` files.

    All locations below are relative to the ``path`` given to the constructor.
    """

    fake_data_path = "generated_audio"
    jsut_real_data_path = "real_audio/jsut_ver1.1/basic5000/wav"
    ljspeech_real_data_path = "real_audio/LJSpeech-1.1/wavs"

    def __init__(self, path, subset="train", transform=None):
        """Collect the sample table for ``subset`` from the tree under ``path``.

        ``subset`` must be one of the subset keys of ``WAVEFAKE_SPLIT``.
        """
        super().__init__(subset, transform)
        self.path = Path(path)

        self.fold_subset = subset
        self.allowed_attacks = WAVEFAKE_SPLIT[subset]
        self.partition_ratio = WAVEFAKE_SPLIT["partition_ratio"]
        self.seed = WAVEFAKE_SPLIT["seed"]

        # Spoofed rows first, then genuine ones, re-indexed from zero.
        spoofed = self.get_fake_samples()
        genuine = self.get_real_samples()
        self.samples = pd.concat([spoofed, genuine], ignore_index=True)

    def get_fake_samples(self):
        """Return a DataFrame describing the generated files of this subset."""
        fake_root = self.path / self.fake_data_path
        candidates = self.filter_samples_by_attack(list(fake_root.glob("*/*.wav")))
        # ``split_samples`` is provided by the base class (not visible here).
        chosen = list(self.split_samples(candidates))

        return pd.DataFrame({
            "user_id": [None for _ in chosen],
            # The sample name is the file stem without its last "_" segment.
            "sample_name": ["_".join(wav.stem.split("_")[:-1]) for wav in chosen],
            "attack_type": [self.get_attack_from_path(wav) for wav in chosen],
            "label": ["spoof" for _ in chosen],
            "path": list(chosen),
        })

    def filter_samples_by_attack(self, samples_list):
        """Keep only the paths whose attack name is in ``self.allowed_attacks``."""
        kept = []
        for candidate in samples_list:
            if self.get_attack_from_path(candidate) in self.allowed_attacks:
                kept.append(candidate)
        return kept

    def get_real_samples(self):
        """Return a DataFrame describing the real files of this subset."""
        real_files = list((self.path / self.jsut_real_data_path).glob("*.wav"))
        real_files += list((self.path / self.ljspeech_real_data_path).glob("*.wav"))
        # ``split_samples`` is provided by the base class (not visible here).
        chosen = list(self.split_samples(real_files))

        return pd.DataFrame({
            "user_id": [None for _ in chosen],
            "sample_name": [wav.stem for wav in chosen],
            "attack_type": ["-" for _ in chosen],
            "label": ["bonafide" for _ in chosen],
            "path": list(chosen),
        })

    @staticmethod
    def get_attack_from_path(path):
        """Return the attack name encoded in the parent folder of ``path``.

        The attack name is everything after the first underscore of the
        parent folder name, or the whole name when it has no underscore.
        """
        parent_dir = path.parents[0]
        grandparent_dir = path.parents[1]
        folder_name = str(parent_dir.relative_to(grandparent_dir))
        pieces = folder_name.split("_", maxsplit=1)
        return pieces[-1]
84
+
85
+
src/frontends.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union, Callable
2
+
3
+ import torch
4
+ import torchaudio
5
+
6
SAMPLING_RATE = 16_000
win_length = 400  # int((25 / 1_000) * SAMPLING_RATE)
hop_length = 160  # int((10 / 1_000) * SAMPLING_RATE)

# Use the GPU for the module-level transforms whenever one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Spectrogram settings shared by the MFCC and LFCC transforms.
_SPECTROGRAM_KWARGS = {
    "n_fft": 512,
    "win_length": win_length,
    "hop_length": hop_length,
}

MFCC_FN = torchaudio.transforms.MFCC(
    sample_rate=SAMPLING_RATE,
    n_mfcc=128,
    melkwargs=dict(_SPECTROGRAM_KWARGS),
)
MFCC_FN = MFCC_FN.to(device)


LFCC_FN = torchaudio.transforms.LFCC(
    sample_rate=SAMPLING_RATE,
    n_lfcc=128,
    speckwargs=dict(_SPECTROGRAM_KWARGS),
)
LFCC_FN = LFCC_FN.to(device)

# n_stft=257 equals 512 // 2 + 1, i.e. it matches n_fft=512 used above.
MEL_SCALE_FN = torchaudio.transforms.MelScale(
    sample_rate=SAMPLING_RATE,
    n_stft=257,
    n_mels=80,
)
MEL_SCALE_FN = MEL_SCALE_FN.to(device)

# Note: unlike the transforms above, this one is not moved to ``device``.
delta_fn = torchaudio.transforms.ComputeDeltas(
    mode="replicate",
    win_length=400,
)
+ )
43
+
44
+
45
def get_frontend(
    frontends: List[str],
) -> Union[torchaudio.transforms.MFCC, torchaudio.transforms.LFCC, Callable,]:
    """Return the feature-extraction callable selected by ``frontends``.

    ``"mfcc"`` takes precedence over ``"lfcc"`` when both are present.

    Raises:
        ValueError: If neither ``"mfcc"`` nor ``"lfcc"`` is in ``frontends``.
    """
    # Checked in priority order; the first match wins.
    supported = (
        ("mfcc", prepare_mfcc_double_delta),
        ("lfcc", prepare_lfcc_double_delta),
    )
    for frontend_name, frontend_fn in supported:
        if frontend_name in frontends:
            return frontend_fn
    raise ValueError(f"{frontends} frontend is not supported!")
53
+
54
+
55
def prepare_lfcc_double_delta(input):
    """Compute LFCC features stacked with their delta and double-delta.

    Inputs with fewer than 4 dimensions get an extra axis inserted at
    position 1 before the transform is applied (presumably a channel axis
    for a batch of waveforms -- TODO confirm against callers).

    Returns:
        The coefficients, deltas and double-deltas concatenated along axis 2,
        truncated to at most the first 512 entries of axis 3.
    """
    if input.ndim < 4:
        input = input.unsqueeze(1)
    coefficients = LFCC_FN(input)
    first_order = delta_fn(coefficients)
    second_order = delta_fn(first_order)
    # Stack along the coefficient axis: 3x as many rows as ``LFCC_FN`` yields.
    stacked = torch.cat((coefficients, first_order, second_order), 2)
    return stacked[:, :, :, :512]
63
+
64
+
65
def prepare_mfcc_double_delta(input):
    """Compute MFCC features stacked with their delta and double-delta.

    Inputs with fewer than 4 dimensions get an extra axis inserted at
    position 1 before the transform is applied (presumably a channel axis
    for a batch of waveforms -- TODO confirm against callers).

    Returns:
        The coefficients, deltas and double-deltas concatenated along axis 2,
        truncated to at most the first 512 entries of axis 3.
    """
    if input.ndim < 4:
        input = input.unsqueeze(1)
    coefficients = MFCC_FN(input)
    first_order = delta_fn(coefficients)
    second_order = delta_fn(first_order)
    # Stack along the coefficient axis: 3x as many rows as ``MFCC_FN`` yields.
    stacked = torch.cat((coefficients, first_order, second_order), 2)
    return stacked[:, :, :, :512]
src/metrics.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ from scipy.interpolate import interp1d
5
+ from scipy.optimize import brentq
6
+ from sklearn.metrics import roc_curve
7
+ from sklearn.metrics import roc_curve
8
+
9
+
10
def calculate_eer(y, y_score) -> Tuple[float, float, np.ndarray, np.ndarray]:
    """Compute the equal error rate (EER) from labels and scores.

    The scores are negated before the ROC curve is built, so the returned
    threshold lives on the negated-score scale.

    Args:
        y: Ground-truth labels passed to ``roc_curve``.
        y_score: Scores; must support unary minus (e.g. a NumPy array).

    Returns:
        ``(thresh, eer, fpr, tpr)``: the interpolated threshold at the EER
        point, the EER itself and the ROC curve arrays.
    """
    fpr, tpr, thresholds = roc_curve(y, -y_score)

    tpr_at = interp1d(fpr, tpr)

    def _fnr_minus_fpr(x):
        # Zero where the false-negative rate (1 - TPR) equals the FPR ``x``.
        return 1.0 - x - tpr_at(x)

    eer = brentq(_fnr_minus_fpr, 0.0, 1.0)
    threshold_at = interp1d(fpr, thresholds)
    thresh = threshold_at(eer)
    return thresh, eer, fpr, tpr
src/trainer.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A generic training wrapper."""
2
+ from copy import deepcopy
3
+ import logging
4
+ from typing import Callable, List, Optional
5
+
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+
9
+
10
+ LOGGER = logging.getLogger(__name__)
11
+
12
+
13
class Trainer:
    """Container for the hyper-parameters shared by concrete trainers."""

    def __init__(
        self,
        epochs: int = 20,
        batch_size: int = 32,
        device: str = "cpu",
        optimizer_fn: Callable = torch.optim.Adam,
        optimizer_kwargs: Optional[dict] = None,
        use_scheduler: bool = False,
    ) -> None:
        """Store the training configuration.

        Args:
            epochs: Number of training epochs.
            batch_size: Batch size used for the data loaders.
            device: Device identifier batches are moved to.
            optimizer_fn: Callable building the optimizer from the model
                parameters and ``optimizer_kwargs``.
            optimizer_kwargs: Keyword arguments for ``optimizer_fn``.
                Defaults to ``{"lr": 1e-3}``.
            use_scheduler: Whether a learning-rate scheduler should be used.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        self.optimizer_fn = optimizer_fn
        # Build a fresh dict per instance: a dict literal in the signature
        # would be shared (and mutated) across every Trainer.
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {"lr": 1e-3}
        )
        self.epoch_test_losses: List[float] = []
        self.use_scheduler = use_scheduler
30
+
31
+
32
def forward_and_loss(model, criterion, batch_x, batch_y, **kwargs):
    """Run ``model`` on ``batch_x`` and score the result against ``batch_y``.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(model_output, loss)``.
    """
    predictions = model(batch_x)
    return predictions, criterion(predictions, batch_y)
36
+
37
+
38
class GDTrainer(Trainer):
    """Gradient-descent trainer for a binary classifier producing one logit."""

    def train(
        self,
        dataset: torch.utils.data.Dataset,
        model: torch.nn.Module,
        test_len: Optional[float] = None,
        test_dataset: Optional[torch.utils.data.Dataset] = None,
    ):
        """Train ``model`` and restore the weights of its best epoch.

        Each batch yielded by the loaders is unpacked as a triple
        ``(batch_x, _, batch_y)``; the middle element is ignored. The loss is
        ``BCEWithLogitsLoss`` and "best" means highest accuracy on the test
        loader.

        Args:
            dataset: Training data, or the full data when ``test_dataset`` is
                not given.
            model: Module to optimise. It is not moved to ``self.device``
                here; only the batches are.
            test_len: Fraction of ``dataset`` held out for testing. Used only
                when ``test_dataset`` is ``None``.
            test_dataset: Explicit test data; takes precedence over
                ``test_len``.

        Returns:
            ``model`` with the best state dict loaded (unchanged when no
            epoch was run).

        Raises:
            ValueError: If neither ``test_dataset`` nor ``test_len`` is given.
        """
        if test_dataset is not None:
            train = dataset
            test = test_dataset
        else:
            if test_len is None:
                # Previously this surfaced as an opaque TypeError from
                # ``len(dataset) * None``.
                raise ValueError(
                    "Either `test_dataset` or `test_len` must be provided."
                )
            test_len = int(len(dataset) * test_len)
            train_len = len(dataset) - test_len
            lengths = [train_len, test_len]
            train, test = torch.utils.data.random_split(dataset, lengths)

        train_loader = DataLoader(
            train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=6,
        )
        test_loader = DataLoader(
            test,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=6,
        )

        criterion = torch.nn.BCEWithLogitsLoss()
        optim = self.optimizer_fn(model.parameters(), **self.optimizer_kwargs)

        best_model = None
        best_acc = 0

        LOGGER.info(f"Starting training for {self.epochs} epochs!")

        forward_and_loss_fn = forward_and_loss

        if self.use_scheduler:
            # The scheduler is stepped once per batch, so a period of two
            # epochs' worth of batches restarts the cosine every 2nd epoch.
            batches_per_epoch = len(train_loader) * 2
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer=optim,
                T_0=batches_per_epoch,
                T_mult=1,
                eta_min=5e-6,
            )
        use_cuda = self.device != "cpu"

        for epoch in range(self.epochs):
            LOGGER.info(f"Epoch num: {epoch}")

            running_loss = 0
            num_correct = 0.0
            num_total = 0.0
            model.train()

            for i, (batch_x, _, batch_y) in enumerate(train_loader):
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                # Targets become a float column matching the model output.
                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)

                batch_out, batch_loss = forward_and_loss_fn(
                    model, criterion, batch_x, batch_y, use_cuda=use_cuda
                )
                # Adding 0.5 before truncating rounds the probability to 0/1.
                batch_pred = (torch.sigmoid(batch_out) + 0.5).int()
                num_correct += (batch_pred == batch_y.int()).sum(dim=0).item()

                running_loss += batch_loss.item() * batch_size

                if i % 100 == 0:
                    LOGGER.info(
                        f"[{epoch:04d}][{i:05d}]: {running_loss / num_total} {num_correct/num_total*100}"
                    )

                optim.zero_grad()
                batch_loss.backward()
                optim.step()
                if self.use_scheduler:
                    scheduler.step()

            running_loss /= num_total
            train_accuracy = (num_correct / num_total) * 100

            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: train/loss: {running_loss}, train/accuracy: {train_accuracy}"
            )

            test_running_loss = 0.0
            num_correct = 0.0
            num_total = 0.0
            model.eval()
            # EER is not computed here; the placeholder is only logged.
            eer_val = 0

            for batch_x, _, batch_y in test_loader:
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                with torch.no_grad():
                    batch_pred = model(batch_x)

                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)
                batch_loss = criterion(batch_pred, batch_y)

                test_running_loss += batch_loss.item() * batch_size

                batch_pred = torch.sigmoid(batch_pred)
                batch_pred_label = (batch_pred + 0.5).int()
                num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            # ``drop_last=True`` can leave the test loader empty; avoid a
            # division by zero in that case.
            if num_total == 0:
                num_total = 1

            test_running_loss /= num_total
            test_acc = 100 * (num_correct / num_total)
            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: test/loss: {test_running_loss}, test/accuracy: {test_acc}, test/eer: {eer_val}"
            )

            if best_model is None or test_acc > best_acc:
                best_acc = test_acc
                # Snapshot the weights; the live state dict keeps changing.
                best_model = deepcopy(model.state_dict())

            LOGGER.info(
                f"[{epoch:04d}]: {running_loss} - train acc: {train_accuracy} - test_acc: {test_acc}"
            )

        # With ``epochs == 0`` no snapshot exists; return the model untouched
        # instead of passing ``None`` to ``load_state_dict``.
        if best_model is not None:
            model.load_state_dict(best_model)
        return model