Spaces:
Runtime error
Runtime error
Upload 28 files
Browse files- src/__init__.py +3 -0
- src/commons.py +22 -0
- src/datasets/KoAAD_dataset.py +45 -0
- src/datasets/MAILABS_dataset.py +50 -0
- src/datasets/MLAADv3_dataset.py +47 -0
- src/datasets/__init__.py +0 -0
- src/datasets/aihub_dataset.py +40 -0
- src/datasets/asvspoof_dataset.py +155 -0
- src/datasets/base_dataset.py +172 -0
- src/datasets/deepfake_asvspoof_dataset.py +86 -0
- src/datasets/detection_dataset.py +156 -0
- src/datasets/fakeavceleb_dataset.py +94 -0
- src/datasets/in_the_wild_dataset.py +62 -0
- src/datasets/wavefake_dataset.py +85 -0
- src/frontends.py +72 -0
- src/metrics.py +15 -0
- src/trainer.py +173 -0
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
src/commons.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility file for src toolkit."""
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
WHISPER_MODEL_WEIGHTS_PATH = "src/models/assets/tiny_enc.en.pt"
|
9 |
+
|
10 |
+
|
11 |
+
def set_seed(seed: int):
    """Seed every pseudo-random number generator used by the toolkit.

    Seeds Python's ``random`` module, NumPy and PyTorch with the same value,
    and exports ``PYTHONHASHSEED``. When CUDA is available the CUDA generators
    are seeded as well and cuDNN is switched to deterministic mode.

    Args:
        seed: Value handed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # NOTE(review): the scraped source lost its indentation; the cuDNN flags
    # are assumed to belong to the CUDA branch - confirm against the repo.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Hash randomization of the running interpreter is fixed at start-up, so
    # this only influences processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
|
src/datasets/KoAAD_dataset.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
import os
|
5 |
+
|
6 |
+
class KoAAD(SimpleAudioFakeDataset):
    """Spoof-only dataset built from audio files found under ``root_path``.

    Every immediate sub-directory of ``root_path`` is searched recursively.
    Inside each sub-directory the first 70% of the (sorted) files form the
    ``"train"`` subset; any other ``subset`` value receives the remaining 30%.
    All samples are labelled ``"spoof"``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        """Collect the samples of ``subset`` below ``root_path``.

        Args:
            root_path: Directory whose sub-directories contain the audio files.
            subset: ``"train"`` for the first 70% of each folder, anything
                else for the rest.
            **kwargs: Forwarded to ``SimpleAudioFakeDataset`` (e.g.
                ``transform``, ``return_label``, ``return_meta``).
        """
        # The base constructor takes (subset, transform, ...). Passing
        # root_path first used to store the subset name as the transform.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Scan the dataset directory and return the samples as a DataFrame.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        # glob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the train/test boundary identical between runs.
        for folder in sorted(self.root_path.glob("*")):
            if not folder.is_dir():
                continue

            # Character-class pattern: matches ".wav" and ".mp3" (as well as
            # any other combination of the listed characters, e.g. ".mpv").
            audio_files = sorted(folder.rglob("*.[wm][ap][v3]"))
            boundary = int(len(audio_files) * 0.7)
            if self.subset == 'train':
                audio_files = audio_files[:boundary]
            else:
                audio_files = audio_files[boundary:]

            for sample in audio_files:
                # Guards against entries that vanished or dangling symlinks.
                if not sample.exists():
                    continue
                samples["user_id"].append(None)
                samples["path"].append(sample)
                samples["sample_name"].append(sample.stem)
                samples["attack_type"].append("-")
                samples["label"].append("spoof")

        print(f"KoAAD_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
|
45 |
+
|
src/datasets/MAILABS_dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
import os
|
5 |
+
|
6 |
+
class MAILABS(SimpleAudioFakeDataset):
    """Bonafide-only dataset built from ``en_US/by_book/*/*`` under ``root_path``.

    For every second-level folder the (sorted) ``.wav`` files are split 70/30:
    ``subset == "train"`` keeps the first part, ``subset == "test"`` the
    second, and any other value keeps all files. All samples are labelled
    ``"bonafide"``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        """Collect the samples of ``subset`` below ``root_path``.

        Args:
            root_path: Directory containing the ``en_US/by_book`` tree.
            subset: ``"train"``, ``"test"`` or anything else for "everything".
            **kwargs: Forwarded to ``SimpleAudioFakeDataset`` (e.g.
                ``transform``, ``return_label``, ``return_meta``).
        """
        # The base constructor takes (subset, transform, ...). Passing
        # root_path first used to store the subset name as the transform.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Scan the dataset directory and return the samples as a DataFrame.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }
        train_share = 0.7

        # glob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the train/test boundary identical between runs.
        for group_dir in sorted(self.root_path.glob("en_US/by_book/*")):
            if not group_dir.is_dir():
                continue

            for path in sorted(group_dir.glob("*")):
                if not path.exists():
                    print(f"{path} 경로를 찾을 수 없습니다.")
                    continue

                wav_files = sorted(path.rglob("*.wav"))
                boundary = int(len(wav_files) * train_share)
                if self.subset == 'train':
                    wav_files = wav_files[:boundary]
                elif self.subset == 'test':
                    wav_files = wav_files[boundary:]

                for sample in wav_files:
                    # Hidden files (leading dot) are skipped after the split,
                    # exactly as before, so they still count towards the ratio.
                    if sample.name.startswith("."):
                        continue
                    if not sample.exists():
                        continue
                    samples["user_id"].append(None)
                    samples["path"].append(sample)
                    samples["sample_name"].append(sample.stem)
                    samples["attack_type"].append("-")
                    samples["label"].append("bonafide")

        print(f"MAILABS_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
|
50 |
+
|
src/datasets/MLAADv3_dataset.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
class MLAADv3(SimpleAudioFakeDataset):
    """Spoof-only dataset built from ``fake/<language>/<folder>`` directories.

    For every language code in ``languages`` and every folder below
    ``root_path/fake/<language>`` the (sorted) ``.wav`` files are split 70/30:
    ``subset == "train"`` keeps the first part, any other value the rest.
    All samples are labelled ``"spoof"``.
    """

    # Language sub-directories that are scanned below "<root_path>/fake".
    languages = ['fr', 'et', 'ar', 'hu', 'bg', 'es', 'el', 'da', 'ga', 'ru', 'fi',
                 'uk', 'pl', 'en', 'sw', 'mt', 'sk', 'ro', 'hi', 'cs', 'nl', 'it', 'de']

    def __init__(self, root_path, subset=None, **kwargs):
        """Collect the samples of ``subset`` below ``root_path``.

        Args:
            root_path: Directory containing the ``fake`` tree.
            subset: ``"train"`` for the first 70% of each folder, anything
                else for the rest.
            **kwargs: Forwarded to ``SimpleAudioFakeDataset`` (e.g.
                ``transform``, ``return_label``, ``return_meta``).
        """
        # The base constructor takes (subset, transform, ...). Passing
        # root_path first used to store the subset name as the transform.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Scan the dataset directory and return the samples as a DataFrame.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``language``,
            ``sample_name``, ``attack_type``, ``label`` and ``path``.
        """
        samples = {
            "user_id": [],
            "language": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        for lang in self.languages:
            r_path = self.root_path / f"fake/{lang}"
            # glob() yields entries in arbitrary, filesystem-dependent order;
            # sorting keeps the train/test boundary identical between runs.
            for path in sorted(r_path.glob("*")):
                if not path.exists():
                    print(f"{path} 경로를 찾을 수 없습니다.")
                    continue

                wav_files = sorted(path.rglob("*.wav"))
                boundary = int(len(wav_files) * 0.7)
                if self.subset == 'train':
                    wav_files = wav_files[:boundary]
                else:
                    wav_files = wav_files[boundary:]

                for sample in wav_files:
                    samples["user_id"].append(None)
                    samples["language"].append(lang)
                    samples["path"].append(sample)
                    samples["sample_name"].append(sample.stem)
                    samples["attack_type"].append("-")
                    samples["label"].append("spoof")

        print(f"__MLAADv3_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
|
src/datasets/__init__.py
ADDED
File without changes
|
src/datasets/aihub_dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
class AIHUB(SimpleAudioFakeDataset):
    """Bonafide-only dataset built from every ``.wav`` file below ``root_path``.

    The (sorted) file list is split 70/30: ``subset == "train"`` keeps the
    first part, any other value the rest. All samples are labelled
    ``"bonafide"``.
    """

    def __init__(self, root_path, subset=None, **kwargs):
        """Collect the samples of ``subset`` below ``root_path``.

        Args:
            root_path: Directory that is searched recursively for ``.wav`` files.
            subset: ``"train"`` for the first 70% of the files, anything else
                for the rest.
            **kwargs: Forwarded to ``SimpleAudioFakeDataset`` (e.g.
                ``transform``, ``return_label``, ``return_meta``).
        """
        # The base constructor takes (subset, transform, ...). Passing
        # root_path first used to store the subset name as the transform.
        super().__init__(subset, **kwargs)
        self.root_path = Path(f'{root_path}')
        self.subset = subset
        self.samples = self.load_samples()

    def load_samples(self):
        """Scan the dataset directory and return the samples as a DataFrame.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``. It is empty when the
            directory does not exist (a message is printed in that case).
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        path = self.root_path

        # Check that the dataset directory exists.
        if not path.exists():
            print(f"{path} 경로를 찾을 수 없습니다.")

        # rglob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the train/test boundary identical between runs.
        wav_files = sorted(path.rglob("*.wav"))
        boundary = int(len(wav_files) * 0.7)
        if self.subset == 'train':
            wav_files = wav_files[:boundary]
        else:
            wav_files = wav_files[boundary:]

        for sample in wav_files:
            samples["user_id"].append(None)
            samples["path"].append(sample)
            samples["sample_name"].append(sample.stem)
            samples["attack_type"].append("-")
            samples["label"].append("bonafide")

        print(f"__AIHUB_{self.subset}:{len(samples['label'])}")
        return pd.DataFrame(samples)
|
40 |
+
|
src/datasets/asvspoof_dataset.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
if __name__ == "__main__":
|
5 |
+
import sys
|
6 |
+
sys.path.append(str(Path(__file__).parent.parent.parent.absolute()))
|
7 |
+
|
8 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
9 |
+
|
10 |
+
ASVSPOOF_SPLIT = {
|
11 |
+
"train": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
|
12 |
+
"test": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
|
13 |
+
"val": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
|
14 |
+
"partition_ratio": [0.7, 0.15],
|
15 |
+
"seed": 45,
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
class ASVSpoofDataset(SimpleAudioFakeDataset):
    """ASVspoof 2019 LA data re-partitioned with the toolkit's own split.

    The protocols of all official subsets (``train``, ``dev``, ``eval``) are
    read; within each protocol the bonafide and the allowed spoofed lines are
    shuffled and partitioned separately by ``split_samples`` and only the
    part belonging to the requested ``subset`` is kept.
    """

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        """Load the samples of ``subset``.

        Args:
            path: Root directory holding the protocol folder and the
                ``ASVspoof2019_LA_<subset>`` directories.
            subset: Key of ``ASVSPOOF_SPLIT`` (``"train"``, ``"test"`` or
                ``"val"``).
            transform: Stored on the instance.
        """
        super().__init__(subset, transform)
        self.path = path

        self.allowed_attacks = ASVSPOOF_SPLIT[subset]
        self.partition_ratio = ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = ASVSPOOF_SPLIT["seed"]

        # Collect the frames first and concatenate once instead of growing
        # a DataFrame inside the loop.
        frames = []
        for official_subset in self.subsets:
            subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{official_subset}"
            protocol_path = self.get_protocol_path(official_subset)
            frames.append(self.read_protocol(subset_dir, protocol_path))
        self.samples = pd.concat(frames) if frames else pd.DataFrame()

        self.transform = transform

    def get_protocol_path(self, subset):
        """Return the protocol ``.txt`` file whose stem contains ``subset``.

        Raises:
            FileNotFoundError: If no protocol file matches. Previously the
                method fell through and returned ``None``, which only failed
                later inside ``open``.
        """
        protocol_dir = Path(self.path) / self.protocol_folder_name
        for candidate in sorted(protocol_dir.glob("*.txt")):
            if subset in candidate.stem:
                return candidate
        raise FileNotFoundError(
            f"No protocol file for subset '{subset}' found in {protocol_dir}"
        )

    def read_protocol(self, subset_dir, protocol_path):
        """Parse one protocol file and keep the lines of the current subset.

        Args:
            subset_dir: Directory containing the ``flac`` folder of the
                protocol's official subset.
            protocol_path: Protocol file to read.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``; spoofed samples first.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []
        with open(protocol_path, "r") as file:
            for line in file:
                if not line.strip():
                    # A blank (e.g. trailing) line carries no sample.
                    continue
                attack_type = line.strip().split(" ")[3]

                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)
                # Lines with any other attack type are ignored.

        for line in self.split_samples(fake_samples):
            samples = self.add_line_to_samples(samples, line, subset_dir)

        for line in self.split_samples(real_samples):
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)

    @staticmethod
    def add_line_to_samples(samples, line, subset_dir):
        """Append the sample described by a protocol ``line`` to ``samples``.

        Args:
            samples: Dict of column lists that is extended in place.
            line: Protocol line with five space-separated fields.
            subset_dir: Directory containing the ``flac`` folder.

        Returns:
            The same ``samples`` dict.
        """
        user_id, sample_name, _, attack_type, label = line.strip().split(" ")
        sample_path = subset_dir / "flac" / f"{sample_name}.flac"

        samples["user_id"].append(user_id)
        samples["sample_name"].append(sample_name)
        samples["attack_type"].append(attack_type)
        samples["label"].append(label)

        # NOTE(review): asserts are stripped under "python -O"; kept as an
        # assert so the exception type seen by callers does not change.
        assert sample_path.exists(), f"Missing audio file: {sample_path}"
        samples["path"].append(sample_path)

        return samples
|
95 |
+
|
96 |
+
class ASVSpoof2019DatasetOriginal(ASVSpoofDataset):
    """ASVspoof 2019 LA data using the official train/dev/eval partition.

    The toolkit subsets are mapped onto the official ones through
    ``subsets`` (``train -> train``, ``test -> dev``, ``val -> eval``) and no
    re-splitting is performed.
    """

    subsets = {"train": "train", "test": "dev", "val": "eval"}

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    # Attack identifiers expected in the protocol of each official subset.
    subset_dirs_attacks = {
        "train": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "dev": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "eval": [
            "A07", "A08", "A09", "A10", "A11", "A12", "A13", "A14", "A15",
            "A16", "A17", "A18", "A19"
        ]
    }

    def __init__(self, path, fold_subset="train"):
        """
        Initialise object. Skip __init__ of ASVSpoofDataset due to different
        logic, but follow SimpleAudioFakeDataset constructor.

        Args:
            path: Root directory holding the protocol folder and the
                ``ASVspoof2019_LA_<subset>`` directories.
            fold_subset: ``"train"``, ``"test"`` or ``"val"``.
        """
        # SimpleAudioFakeDataset.__init__ takes (subset, transform, ...).
        # The former call passed float('inf') as the subset and the subset
        # name as the transform.
        super(ASVSpoofDataset, self).__init__(fold_subset)
        self.path = path
        subset = self.subsets[fold_subset]
        self.allowed_attacks = self.subset_dirs_attacks[subset]
        subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
        subset_protocol_path = self.get_protocol_path(subset)
        self.samples = self.read_protocol(subset_dir, subset_protocol_path)

    def read_protocol(self, subset_dir, protocol_path):
        """Parse one protocol file without re-partitioning it.

        Args:
            subset_dir: Directory containing the ``flac`` folder.
            protocol_path: Protocol file to read.

        Returns:
            ``pd.DataFrame`` with the columns ``user_id``, ``sample_name``,
            ``attack_type``, ``label`` and ``path``; spoofed samples first.

        Raises:
            ValueError: If a line carries an attack type that is not listed
                for this subset in ``subset_dirs_attacks``.
        """
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []

        with open(protocol_path, "r") as file:
            for line in file:
                if not line.strip():
                    # A blank (e.g. trailing) line carries no sample.
                    continue
                attack_type = line.strip().split(" ")[3]
                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)
                else:
                    raise ValueError(
                        "Tried to load attack that shouldn't be here!"
                    )

        for line in fake_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)
        for line in real_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)
|
155 |
+
|
src/datasets/base_dataset.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base dataset classes."""
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torch.utils.data.dataset import T_co
|
12 |
+
|
13 |
+
|
14 |
+
LOGGER = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
SAMPLING_RATE = 16_000
|
17 |
+
APPLY_NORMALIZATION = True
|
18 |
+
APPLY_TRIMMING = True
|
19 |
+
APPLY_PADDING = True
|
20 |
+
FRAMES_NUMBER = 480_000 # <- originally 64_600
|
21 |
+
|
22 |
+
|
23 |
+
SOX_SILENCE = [
|
24 |
+
# trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file)
|
25 |
+
# from beginning and middle/end
|
26 |
+
["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
|
27 |
+
]
|
28 |
+
|
29 |
+
|
30 |
+
class SimpleAudioFakeDataset(Dataset):
    """Base class of the audio deepfake datasets.

    Sub-classes fill ``self.samples`` (a ``pd.DataFrame`` with at least the
    columns ``path``, ``label`` and ``attack_type``, or a list of
    ``(path, label, attack_type)`` tuples). ``__getitem__`` loads and
    pre-processes the referenced audio file.
    """

    def __init__(
        self,
        subset,
        transform=None,
        return_label: bool = True,
        return_meta: bool = False,
    ):
        """Store the configuration; sub-classes populate ``samples``.

        Args:
            subset: Name of the subset, one of ``"train"``, ``"test"`` or
                ``"val"`` when ``split_samples`` is used.
            transform: Stored on the instance; not applied in this class.
            return_label: Whether ``__getitem__`` appends the numeric label.
            return_meta: Whether ``__getitem__`` appends the metadata tuple.
        """
        self.transform = transform
        self.samples = pd.DataFrame()

        self.subset = subset
        self.allowed_attacks = None
        self.partition_ratio = None
        self.seed = None
        self.return_label = return_label
        self.return_meta = return_meta

    def split_samples(self, samples_list):
        """Shuffle ``samples_list`` reproducibly and return this subset's part.

        The samples are sorted, shuffled with ``self.seed`` and cut at
        ``p`` and ``p + s`` of their length, where
        ``p, s = self.partition_ratio``; the three parts are the ``train``,
        ``test`` and ``val`` subsets.

        Args:
            samples_list: ``pd.DataFrame`` or a sortable sequence.

        Returns:
            A ``pd.DataFrame`` for DataFrame input, otherwise a NumPy array.

        Raises:
            KeyError: If ``self.subset`` is not ``train``/``test``/``val``.
        """
        if isinstance(samples_list, pd.DataFrame):
            ordered = samples_list.sort_values(by=list(samples_list.columns))
            shuffled = ordered.sample(frac=1, random_state=self.seed)
        else:
            ordered = sorted(samples_list)
            # NOTE: this re-seeds the module-level generator of ``random``,
            # which is kept because other code may rely on that state.
            random.seed(self.seed)
            random.shuffle(ordered)
            shuffled = np.asarray(ordered)

        p, s = self.partition_ratio
        total = len(shuffled)
        first_cut = int(p * total)
        second_cut = int((p + s) * total)
        windows = {
            "train": slice(0, first_cut),
            "test": slice(first_cut, second_cut),
            "val": slice(second_cut, None),
        }
        window = windows[self.subset]

        # Plain positional slicing replaces np.split, which routes DataFrames
        # through the deprecated DataFrame.swapaxes.
        if isinstance(shuffled, pd.DataFrame):
            return shuffled.iloc[window]
        return shuffled[window]

    def df2tuples(self):
        """Replace ``self.samples`` by ``(path, label, attack_type)`` tuples.

        Returns:
            The new list, which is also stored in ``self.samples``.
        """
        frame = self.samples
        self.samples = [
            (str(path), label, attack_type)
            for path, label, attack_type in zip(
                frame["path"], frame["label"], frame["attack_type"]
            )
        ]
        return self.samples

    def __getitem__(self, index) -> list:
        """Load and pre-process the sample at ``index``.

        Returns:
            ``[waveform, sample_rate]``, followed by the label (``1`` for
            ``"bonafide"``, ``0`` otherwise) when ``return_label`` is set and
            by ``(attack_type, path, subset, length_in_seconds)`` when
            ``return_meta`` is set.
        """
        if isinstance(self.samples, pd.DataFrame):
            sample = self.samples.iloc[index]

            path = str(sample["path"])
            label = sample["label"]
            attack_type = sample["attack_type"]
            # Missing attack types (NaN/None) are reported as "N/A".
            if not isinstance(attack_type, str) and pd.isna(attack_type):
                attack_type = "N/A"
        else:
            path, label, attack_type = self.samples[index]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        # Duration of the file as stored on disk, before any pre-processing.
        real_sec_length = len(waveform[0]) / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_label:
            label = 1 if label == "bonafide" else 0
            return_data.append(label)

        if self.return_meta:
            return_data.append(
                (
                    attack_type,
                    path,
                    self.subset,
                    real_sec_length,
                )
            )
        return return_data

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)
|
108 |
+
|
109 |
+
|
110 |
+
def apply_preprocessing(
    waveform,
    sample_rate,
):
    """Run the module's pre-processing chain on one waveform.

    Steps, in order: resample to ``SAMPLING_RATE`` (skipped when that constant
    is ``-1`` or already matches), keep only the first channel, remove
    silence via ``apply_trim`` when ``APPLY_TRIMMING`` is set, and bring the
    signal to ``FRAMES_NUMBER`` samples via ``apply_pad`` when
    ``APPLY_PADDING`` is set.

    Args:
        waveform: Audio tensor.  # presumably (channels, frames) - confirm
        sample_rate: Sampling rate of ``waveform``.

    Returns:
        Tuple ``(waveform, sample_rate)`` after pre-processing.
    """
    resampling_enabled = SAMPLING_RATE != -1
    if resampling_enabled and sample_rate != SAMPLING_RATE:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Multi-channel input: keep the first channel only.
    has_several_channels = waveform.dim() > 1 and waveform.shape[0] > 1
    if has_several_channels:
        waveform = waveform[:1, ...]

    # Strip silence ...
    if APPLY_TRIMMING:
        waveform, sample_rate = apply_trim(waveform, sample_rate)

    # ... then cut or repeat the signal to the fixed length.
    if APPLY_PADDING:
        waveform = apply_pad(waveform, FRAMES_NUMBER)

    return waveform, sample_rate
|
130 |
+
|
131 |
+
|
132 |
+
def resample_wave(waveform, sample_rate, target_sample_rate):
    """Resample an in-memory waveform with the SoX ``rate`` effect.

    Args:
        waveform: Audio tensor to resample.
        sample_rate: Current sampling rate of ``waveform``.
        target_sample_rate: Desired sampling rate.

    Returns:
        Tuple ``(waveform, sample_rate)`` as returned by
        ``torchaudio.sox_effects.apply_effects_tensor``.
    """
    rate_effect = [["rate", f"{target_sample_rate}"]]
    resampled, new_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, rate_effect
    )
    return resampled, new_rate
|
137 |
+
|
138 |
+
|
139 |
+
def resample_file(path, target_sample_rate, normalize=True):
    """Load an audio file and resample it with the SoX ``rate`` effect.

    Args:
        path: Audio file to load.
        target_sample_rate: Desired sampling rate.
        normalize: Forwarded to
            ``torchaudio.sox_effects.apply_effects_file``.

    Returns:
        Tuple ``(waveform, sample_rate)`` as returned by
        ``torchaudio.sox_effects.apply_effects_file``.
    """
    rate_effect = [["rate", f"{target_sample_rate}"]]
    resampled, new_rate = torchaudio.sox_effects.apply_effects_file(
        path, rate_effect, normalize=normalize
    )
    return resampled, new_rate
|
145 |
+
|
146 |
+
|
147 |
+
def apply_trim(waveform, sample_rate):
    """Remove silence using the ``SOX_SILENCE`` effect chain.

    The input is returned untouched when trimming would leave no samples.

    Args:
        waveform: Audio tensor to trim.
        sample_rate: Sampling rate of ``waveform``.

    Returns:
        Tuple ``(waveform, sample_rate)``.
    """
    trimmed, trimmed_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, SOX_SILENCE
    )

    # Fall back to the original signal if everything was classified as silence.
    if trimmed.size()[1] > 0:
        return trimmed, trimmed_rate
    return waveform, sample_rate
|
158 |
+
|
159 |
+
|
160 |
+
def apply_pad(waveform, cut):
    """Pad wave by repeating signal until `cut` length is achieved.

    Signals that already have at least ``cut`` samples are truncated to
    ``cut`` instead.

    Args:
        waveform: Audio tensor; a leading dimension of size one is squeezed
            away before the length is measured.
        cut: Number of samples of the returned signal.

    Returns:
        Tensor holding ``cut`` samples.

    Raises:
        ValueError: If the waveform is empty and therefore cannot be
            repeated (this used to surface as a ``ZeroDivisionError``).
    """
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]

    if waveform_len >= cut:
        return waveform[:cut]

    if waveform_len == 0:
        raise ValueError("Cannot pad an empty waveform by repetition.")

    # One repetition more than strictly fits, then cut down to size.
    num_repeats = cut // waveform_len + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]

    return padded_waveform
|
src/datasets/deepfake_asvspoof_dataset.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
7 |
+
|
8 |
+
DF_ASVSPOOF_SPLIT = {
|
9 |
+
"partition_ratio": [0.7, 0.15],
|
10 |
+
"seed": 45
|
11 |
+
}
|
12 |
+
|
13 |
+
LOGGER = logging.getLogger()
|
14 |
+
|
15 |
+
class DeepFakeASVSpoofDataset(SimpleAudioFakeDataset):
    """ASVspoof 2021 DF evaluation data partitioned with the toolkit's split.

    The ``.flac`` files of all ``subset_parts`` are indexed by file stem, the
    trial metadata is read, and the bonafide and spoofed lines are shuffled
    and partitioned separately by ``split_samples``.
    """

    protocol_file_name = "keys/CM/trial_metadata.txt"
    subset_dir_prefix = "ASVspoof2021_DF_eval"
    subset_parts = ("part00", "part01", "part02", "part03")

    def __init__(self, path, subset="train", transform=None):
        """Load the samples of ``subset``.

        Args:
            path: Root directory holding the part folders and the key file.
            subset: ``"train"``, ``"test"`` or ``"val"``.
            transform: Stored on the instance.
        """
        super().__init__(subset, transform)
        self.path = path

        self.partition_ratio = DF_ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = DF_ASVSPOOF_SPLIT["seed"]

        self.flac_paths = self.get_file_references()
        self.samples = self.read_protocol()

        self.transform = transform
        labels = self.samples["label"]
        LOGGER.info("Spoof: %s", int((labels == "spoof").sum()))
        LOGGER.info("Original: %s", int((labels == "bonafide").sum()))

    def get_file_references(self):
        """Map the stem of every ``.flac`` file to its path.

        Returns:
            Dict ``{file stem: Path}`` covering all ``subset_parts``.
        """
        flac_paths = {}
        for part in self.subset_parts:
            flac_dir = (
                Path(self.path)
                / f"{self.subset_dir_prefix}_{part}"
                / self.subset_dir_prefix
                / "flac"
            )
            for flac_file in flac_dir.glob("*.flac"):
                flac_paths[flac_file.stem] = flac_file

        return flac_paths

    def read_protocol(self):
        """Parse the trial metadata and keep the lines of the current subset.

        Lines whose label field is neither ``bonafide`` nor ``spoof`` are
        ignored, as are blank lines.

        Returns:
            ``pd.DataFrame`` with the columns ``sample_name``, ``label``,
            ``path`` and ``attack_type``; spoofed samples first.
        """
        samples = {
            "sample_name": [],
            "label": [],
            "path": [],
            "attack_type": [],
        }

        real_samples = []
        fake_samples = []
        with open(Path(self.path) / self.protocol_file_name, "r") as file:
            for line in file:
                if not line.strip():
                    continue
                label = line.strip().split(" ")[5]

                if label == "bonafide":
                    real_samples.append(line)
                elif label == "spoof":
                    fake_samples.append(line)

        for line in self.split_samples(fake_samples):
            samples = self.add_line_to_samples(samples, line)

        for line in self.split_samples(real_samples):
            samples = self.add_line_to_samples(samples, line)

        return pd.DataFrame(samples)

    def add_line_to_samples(self, samples, line):
        """Append the sample described by a metadata ``line`` to ``samples``.

        Only the second field (sample name) and the sixth field (label) are
        used, so lines with additional trailing columns are accepted.

        Args:
            samples: Dict of column lists that is extended in place.
            line: Space-separated metadata line.

        Returns:
            The same ``samples`` dict.

        Raises:
            KeyError: If no ``.flac`` file was indexed for the sample name.
        """
        fields = line.strip().split(" ")
        sample_name, label = fields[1], fields[5]

        samples["sample_name"].append(sample_name)
        samples["label"].append(label)
        # The label doubles as the attack type for this dataset.
        samples["attack_type"].append(label)

        sample_path = self.flac_paths[sample_name]
        assert sample_path.exists()
        samples["path"].append(sample_path)

        return samples
|
86 |
+
|
src/datasets/detection_dataset.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
8 |
+
from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
|
9 |
+
from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
|
10 |
+
from src.datasets.wavefake_dataset import WaveFakeDataset
|
11 |
+
from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal
|
12 |
+
from src.datasets.MLAADv3_dataset import MLAADv3
|
13 |
+
from src.datasets.MAILABS_dataset import MAILABS
|
14 |
+
from src.datasets.aihub_dataset import AIHUB
|
15 |
+
from src.datasets.KoAAD_dataset import KoAAD
|
16 |
+
|
17 |
+
|
18 |
+
LOGGER = logging.getLogger()
|
19 |
+
|
20 |
+
|
21 |
+
class DetectionDataset(SimpleAudioFakeDataset):
    """Combination of every dataset for which a path is given.

    The ``samples`` frames of the selected datasets are concatenated and can
    then be balanced (over- or undersampling) and reduced in size.
    """

    def __init__(
        self,
        asvspoof_path=None,
        wavefake_path=None,
        fakeavceleb_path=None,
        asvspoof2019_path=None,
        MLAADv3_path=None,
        MAILABS_path=None,
        AIHUB_path=None,
        KoAAD_path=None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        """Build the combined sample table.

        Args:
            asvspoof_path: Root of the ASVspoof 2021 DF data, or ``None``.
            wavefake_path: Root of the WaveFake data, or ``None``.
            fakeavceleb_path: Root of the FakeAVCeleb data, or ``None``.
            asvspoof2019_path: Root of the ASVspoof 2019 LA data, or ``None``.
            MLAADv3_path: Root of the MLAADv3 data, or ``None``.
            MAILABS_path: Root of the M-AILABS data, or ``None``.
            AIHUB_path: Root of the AIHUB data, or ``None``.
            KoAAD_path: Root of the KoAAD data, or ``None``.
            subset: Subset name handed to every dataset.
            transform: Stored on the instance.
            oversample: Duplicate bonafide samples until both classes have
                the same size. Takes precedence over ``undersample``.
            undersample: Draw as many spoofed samples as there are bonafide
                ones.
            return_label: See ``SimpleAudioFakeDataset``.
            reduced_number: If truthy, keep at most this many samples
                (drawn with ``random_state=42``).
            return_meta: See ``SimpleAudioFakeDataset``.
        """
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )
        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info(f"Using reduced number of samples - {reduced_number}!")
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
    ) -> List[SimpleAudioFakeDataset]:
        """Instantiate one dataset per path that is not ``None``.

        The last four parameters formerly used ``Optional[str]`` as their
        default *value*, so an omitted path was treated as present; they now
        default to ``None``.

        Returns:
            The created datasets, in the order of the parameters.
        """
        datasets = []

        if asvspoof_path is not None:
            datasets.append(DeepFakeASVSpoofDataset(asvspoof_path, subset=subset))

        if wavefake_path is not None:
            datasets.append(WaveFakeDataset(wavefake_path, subset=subset))

        if fakeavceleb_path is not None:
            datasets.append(FakeAVCelebDataset(fakeavceleb_path, subset=subset))

        if asvspoof2019_path is not None:
            datasets.append(
                ASVSpoof2019DatasetOriginal(asvspoof2019_path, fold_subset=subset)
            )

        if MLAADv3_path is not None:
            datasets.append(MLAADv3(MLAADv3_path, subset=subset))

        if MAILABS_path is not None:
            datasets.append(MAILABS(MAILABS_path, subset=subset))

        if AIHUB_path is not None:
            datasets.append(AIHUB(AIHUB_path, subset=subset))

        if KoAAD_path is not None:
            datasets.append(KoAAD(KoAAD_path, subset=subset))

        return datasets

    def oversample_dataset(self):
        """Add randomly repeated bonafide samples until the classes match.

        Raises:
            KeyError: If one of the two labels does not occur.
            NotImplementedError: If there are more bonafide than spoofed
                samples.
        """
        # Grouping by the scalar column name keeps the group keys scalar;
        # a one-element list makes pandas expect tuple keys.
        by_label = self.samples.groupby(by="label")
        bona_length = len(by_label.groups["bonafide"])
        spoof_length = len(by_label.groups["spoof"])

        diff_length = spoof_length - bona_length

        if diff_length < 0:
            raise NotImplementedError(
                "Oversampling is only implemented for a spoof majority."
            )

        if diff_length > 0:
            bonafide = by_label.get_group("bonafide").sample(diff_length, replace=True)
            self.samples = pd.concat([self.samples, bonafide], ignore_index=True)

    def undersample_dataset(self):
        """Reduce the spoofed samples to the number of bonafide samples.

        Raises:
            KeyError: If one of the two labels does not occur.
            NotImplementedError: If there are more bonafide than spoofed
                samples.
        """
        by_label = self.samples.groupby(by="label")
        bona_length = len(by_label.groups["bonafide"])
        spoof_length = len(by_label.groups["spoof"])

        if spoof_length < bona_length:
            raise NotImplementedError(
                "Undersampling is only implemented for a spoof majority."
            )

        if spoof_length > bona_length:
            # NOTE(review): sampling with replacement can select the same
            # spoofed file several times - confirm this is intended.
            spoofs = by_label.get_group("spoof").sample(bona_length, replace=True)
            self.samples = pd.concat(
                [by_label.get_group("bonafide"), spoofs], ignore_index=True
            )

    def get_bonafide_only(self):
        """Restrict ``self.samples`` to bonafide samples and return them."""
        self.samples = self.samples.groupby(by="label").get_group("bonafide")
        return self.samples

    def get_spoof_only(self):
        """Restrict ``self.samples`` to spoofed samples and return them."""
        self.samples = self.samples.groupby(by="label").get_group("spoof")
        return self.samples
|
src/datasets/fakeavceleb_dataset.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
6 |
+
|
7 |
+
# Split configuration read by FakeAVCelebDataset.__init__:
#   "train" / "test" / "val" - values of the metadata "method" column that are
#                              accepted as attacks for that subset (the same
#                              four methods for every subset).
#   "partition_ratio"        - stored as ``partition_ratio`` on the dataset;
#                              presumably the train and test fractions used by
#                              the base-class ``split_samples`` - TODO confirm.
#   "seed"                   - stored as ``seed`` on the dataset.
FAKEAVCELEB_SPLIT = {
    "train": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "test": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "val": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "partition_ratio": [0.7, 0.15],
    "seed": 45
}
|
14 |
+
|
15 |
+
|
16 |
+
class FakeAVCelebDataset(SimpleAudioFakeDataset):
    """Audio samples of FakeAVCeleb described by its ``meta_data.csv`` file.

    On construction the metadata is loaded, the spoofed and the genuine rows
    are each passed through the inherited ``split_samples`` and the results
    are concatenated into ``self.samples`` with the columns ``user_id``,
    ``sample_name``, ``attack_type``, ``label`` and ``path``.
    """

    audio_folder = "FakeAVCeleb-audio"
    audio_extension = ".mp3"
    metadata_file = Path(audio_folder) / "meta_data.csv"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.subset = subset
        self.allowed_attacks = FAKEAVCELEB_SPLIT[subset]
        self.partition_ratio = FAKEAVCELEB_SPLIT["partition_ratio"]
        self.seed = FAKEAVCELEB_SPLIT["seed"]

        self.metadata = self.get_metadata()

        # Spoofed rows are collected first, then genuine ones.
        spoofed = self.get_fake_samples()
        genuine = self.get_real_samples()
        self.samples = pd.concat([spoofed, genuine], ignore_index=True)

    def get_metadata(self):
        """Read the metadata CSV and add an ``audio_type`` column.

        ``audio_type`` is the part of the ``type`` column after its last
        ``"-"``.
        """
        metadata = pd.read_csv(Path(self.path) / self.metadata_file)
        metadata["audio_type"] = metadata["type"].apply(
            lambda kind: kind.split("-")[-1]
        )
        return metadata

    def get_fake_samples(self):
        """Return a frame with one ``"spoof"`` row per selected fake sample."""
        columns = {
            name: []
            for name in ("user_id", "sample_name", "attack_type", "label", "path")
        }

        for attack_name in self.allowed_attacks:
            is_attack = self.metadata["method"] == attack_name
            is_fake_audio = self.metadata["audio_type"] == "FakeAudio"
            attack_rows = self.metadata[is_attack & is_fake_audio]

            # The row iterator (not the frame) is what gets handed to the
            # inherited split_samples here.
            selected = self.split_samples(attack_rows.iterrows())

            for _, row in selected:
                columns["user_id"].append(row["source"])
                columns["sample_name"].append(Path(row["filename"]).stem)
                columns["attack_type"].append(row["method"])
                columns["label"].append("spoof")
                columns["path"].append(self.get_file_path(row))

        return pd.DataFrame(columns)

    def get_real_samples(self):
        """Return a frame with one ``"bonafide"`` row per selected real sample."""
        columns = {
            name: []
            for name in ("user_id", "sample_name", "attack_type", "label", "path")
        }

        is_real = self.metadata["method"] == "real"
        is_real_audio = self.metadata["audio_type"] == "RealAudio"
        selected = self.split_samples(self.metadata[is_real & is_real_audio])

        for _, row in selected.iterrows():
            columns["user_id"].append(row["source"])
            columns["sample_name"].append(Path(row["filename"]).stem)
            columns["attack_type"].append("-")
            columns["label"].append("bonafide")
            columns["path"].append(self.get_file_path(row))

        return pd.DataFrame(columns)

    def get_file_path(self, sample):
        """Build the audio file location for one metadata row.

        The first component of the row's ``path`` is replaced by
        ``audio_folder`` and the extension of ``filename`` is replaced by
        ``audio_extension``.
        """
        relative_dir = "/".join([self.audio_folder, *sample["path"].split("/")[1:]])
        file_name = Path(sample["filename"]).with_suffix(self.audio_extension)
        return Path(self.path) / relative_dir / file_name
|
94 |
+
|
src/datasets/in_the_wild_dataset.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
6 |
+
|
7 |
+
|
8 |
+
class InTheWildDataset(SimpleAudioFakeDataset):
    """Dataset built from the ``meta.csv`` file found in ``path``.

    ``read_samples`` normalises the metadata into ``self.samples`` with the
    columns ``sample_name``, ``user_id``, ``label``, ``path`` and
    ``attack_type``.

    Args:
        path: directory that contains ``meta.csv`` and the audio files.
        subset: forwarded to the base class; also selects the part returned
            by ``split_samples_per_speaker`` (``"train"``, ``"test"`` or
            ``"val"``).
        transform: forwarded to the base class.
        seed: random state used when shuffling speakers.
        partition_ratio: ``(train_fraction, test_fraction)`` of speakers; the
            remainder forms the ``"val"`` part.
        split_strategy: accepted for interface compatibility; this class
            does not read it.
    """

    def __init__(
        self,
        path,
        subset="train",
        transform=None,
        seed=None,
        partition_ratio=(0.7, 0.15),
        split_strategy="random"
    ):
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.read_samples()
        self.partition_ratio = partition_ratio
        self.seed = seed

    def read_samples(self):
        """Load ``meta.csv`` into ``self.samples`` and normalise its columns."""
        root = Path(self.path)

        self.samples = pd.read_csv(root / "meta.csv")
        # Location of the audio file: the dataset root joined with the "file" entry.
        self.samples["path"] = self.samples["file"].apply(lambda n: str(root / n))
        # Keep only the stem of the file name as the sample name.
        self.samples["file"] = self.samples["file"].apply(lambda n: Path(n).stem)
        self.samples["label"] = self.samples["label"].map({"bona-fide": "bonafide", "spoof": "spoof"})
        self.samples["attack_type"] = self.samples["label"].map({"bonafide": "-", "spoof": "X"})
        self.samples.rename(columns={'file': 'sample_name', 'speaker': 'user_id'}, inplace=True)

    def split_samples_per_speaker(self, samples):
        """Return the rows of ``samples`` whose speaker belongs to ``self.subset``.

        The unique ``user_id`` values are sorted, shuffled with
        ``self.seed`` and cut into train/test/val parts according to
        ``self.partition_ratio``, so all rows of one speaker end up in the
        same part.

        Raises:
            KeyError: if ``self.subset`` is not ``"train"``, ``"test"`` or
                ``"val"``.
        """
        speaker_list = pd.Series(samples["user_id"].unique())
        # Sort before shuffling so the result depends only on the seed, not
        # on the row order of the metadata file.
        speaker_list = speaker_list.sort_values()
        speaker_list = speaker_list.sample(frac=1, random_state=self.seed)
        speaker_list = list(speaker_list)

        p, s = self.partition_ratio
        cut_points = [int(p * len(speaker_list)), int((p + s) * len(speaker_list))]
        subsets = np.split(speaker_list, cut_points)
        speaker_subset = dict(zip(['train', 'test', 'val'], subsets))[self.subset]
        # Filter the frame that was passed in (previously ``self.samples`` was
        # filtered regardless of the argument).
        return samples[samples["user_id"].isin(speaker_subset)]
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a local copy and print its
    # size, the number of distinct speakers, the speakers and the first item.
    dataset = InTheWildDataset(
        path="../datasets/release_in_the_wild",
        subset="val",
        seed=242,
        split_strategy="per_speaker"
    )

    print(len(dataset))
    print(len(dataset.samples["user_id"].unique()))
    print(dataset.samples["user_id"].unique())

    print(dataset[0])
|
src/datasets/wavefake_dataset.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
from src.datasets.base_dataset import SimpleAudioFakeDataset
|
6 |
+
|
7 |
+
# Split configuration read by WaveFakeDataset.__init__:
#   "train" / "test" / "val" - attack names (as returned by
#                              WaveFakeDataset.get_attack_from_path) accepted
#                              for that subset (the same list for every subset).
#   "partition_ratio"        - stored as ``partition_ratio`` on the dataset;
#                              presumably the train and test fractions used by
#                              the base-class ``split_samples`` - TODO confirm.
#   "seed"                   - stored as ``seed`` on the dataset.
WAVEFAKE_SPLIT = {
    "train": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "test": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "val": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "partition_ratio": [0.7, 0.15],
    "seed": 45
}
|
14 |
+
|
15 |
+
|
16 |
+
class WaveFakeDataset(SimpleAudioFakeDataset):
    """WaveFake generated audio together with the JSUT / LJSpeech real audio.

    On construction the generated ``.wav`` files (filtered by allowed attack)
    and the real ``.wav`` files are each passed through the inherited
    ``split_samples`` and concatenated into ``self.samples`` with the columns
    ``user_id``, ``sample_name``, ``attack_type``, ``label`` and ``path``.
    """

    fake_data_path = "generated_audio"
    jsut_real_data_path = "real_audio/jsut_ver1.1/basic5000/wav"
    ljspeech_real_data_path = "real_audio/LJSpeech-1.1/wavs"

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = Path(path)

        self.fold_subset = subset
        self.allowed_attacks = WAVEFAKE_SPLIT[subset]
        self.partition_ratio = WAVEFAKE_SPLIT["partition_ratio"]
        self.seed = WAVEFAKE_SPLIT["seed"]

        # Spoofed rows are collected first, then genuine ones.
        spoofed = self.get_fake_samples()
        genuine = self.get_real_samples()
        self.samples = pd.concat([spoofed, genuine], ignore_index=True)

    def get_fake_samples(self):
        """Return a frame with one ``"spoof"`` row per selected generated file."""
        columns = {
            name: []
            for name in ("user_id", "sample_name", "attack_type", "label", "path")
        }

        candidates = list((self.path / self.fake_data_path).glob("*/*.wav"))
        candidates = self.filter_samples_by_attack(candidates)

        for wav_path in self.split_samples(candidates):
            columns["user_id"].append(None)
            # Sample name: the file stem without its last "_"-separated part.
            columns["sample_name"].append("_".join(wav_path.stem.split("_")[:-1]))
            columns["attack_type"].append(self.get_attack_from_path(wav_path))
            columns["label"].append("spoof")
            columns["path"].append(wav_path)

        return pd.DataFrame(columns)

    def filter_samples_by_attack(self, samples_list):
        """Keep only the paths whose attack name is in ``self.allowed_attacks``."""
        kept = []
        for candidate in samples_list:
            if self.get_attack_from_path(candidate) in self.allowed_attacks:
                kept.append(candidate)
        return kept

    def get_real_samples(self):
        """Return a frame with one ``"bonafide"`` row per selected real file."""
        columns = {
            name: []
            for name in ("user_id", "sample_name", "attack_type", "label", "path")
        }

        candidates = list((self.path / self.jsut_real_data_path).glob("*.wav"))
        candidates += list((self.path / self.ljspeech_real_data_path).glob("*.wav"))

        for wav_path in self.split_samples(candidates):
            columns["user_id"].append(None)
            columns["sample_name"].append(wav_path.stem)
            columns["attack_type"].append("-")
            columns["label"].append("bonafide")
            columns["path"].append(wav_path)

        return pd.DataFrame(columns)

    @staticmethod
    def get_attack_from_path(path):
        """Derive the attack name from the directory that contains ``path``.

        The name is whatever follows the first ``"_"`` of the parent
        directory name (the whole name when it contains no ``"_"``).
        """
        containing_dir = path.parents[0].relative_to(path.parents[1])
        pieces = str(containing_dir).split("_", maxsplit=1)
        return pieces[-1]
|
84 |
+
|
85 |
+
|
src/frontends.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union, Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
# Shared analysis parameters for the front-end transforms below.
SAMPLING_RATE = 16_000
win_length = 400  # int((25 / 1_000) * SAMPLING_RATE)
hop_length = 160  # int((10 / 1_000) * SAMPLING_RATE)

# The transforms are moved to the GPU when one is available at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"

# MFCC transform producing 128 coefficients.
MFCC_FN = torchaudio.transforms.MFCC(
    sample_rate=SAMPLING_RATE,
    n_mfcc=128,
    melkwargs={
        "n_fft": 512,
        "win_length": win_length,
        "hop_length": hop_length,
    },
).to(device)


# LFCC transform producing 128 coefficients.
LFCC_FN = torchaudio.transforms.LFCC(
    sample_rate=SAMPLING_RATE,
    n_lfcc=128,
    speckwargs={
        "n_fft": 512,
        "win_length": win_length,
        "hop_length": hop_length,
    },
).to(device)

# Mel filter bank with 80 bands; n_stft=257 corresponds to n_fft=512 (512 // 2 + 1).
MEL_SCALE_FN = torchaudio.transforms.MelScale(
    n_mels=80,
    n_stft=257,
    sample_rate=SAMPLING_RATE,
).to(device)

# Delta (derivative) transform applied on top of the cepstral features.
# NOTE(review): win_length=400 equals the STFT window length in samples
# above; ComputeDeltas counts its window in frames - confirm this is intended.
delta_fn = torchaudio.transforms.ComputeDeltas(
    win_length=400,
    mode="replicate",
)
|
43 |
+
|
44 |
+
|
45 |
+
def get_frontend(
    frontends: List[str],
) -> Union[torchaudio.transforms.MFCC, torchaudio.transforms.LFCC, Callable,]:
    """Return the feature-extraction function selected by ``frontends``.

    ``"mfcc"`` takes precedence over ``"lfcc"`` when both are present.

    Raises:
        ValueError: if neither ``"mfcc"`` nor ``"lfcc"`` is in ``frontends``.
    """
    # Checked in order, so the first matching entry wins.
    supported = (
        ("mfcc", prepare_mfcc_double_delta),
        ("lfcc", prepare_lfcc_double_delta),
    )
    for key, prepare_fn in supported:
        if key in frontends:
            return prepare_fn
    raise ValueError(f"{frontends} frontend is not supported!")
|
53 |
+
|
54 |
+
|
55 |
+
def _prepare_double_delta(input, cepstral_fn):
    """Apply ``cepstral_fn`` and append first- and second-order deltas.

    Inputs with fewer than 4 dimensions get an extra dimension inserted at
    position 1 before the transform. The coefficients, their deltas and the
    deltas of the deltas are concatenated along dimension 2, and the last
    dimension is truncated to at most 512 entries.
    """
    if input.ndim < 4:
        input = input.unsqueeze(1)
    x = cepstral_fn(input)
    delta = delta_fn(x)
    double_delta = delta_fn(delta)
    # Stack coefficients, deltas and double deltas along dim 2 (three times
    # the number of coefficients produced by ``cepstral_fn``).
    x = torch.cat((x, delta, double_delta), 2)
    return x[:, :, :, :512]


def prepare_lfcc_double_delta(input):
    """Return LFCC features of ``input`` with deltas and double deltas appended."""
    return _prepare_double_delta(input, LFCC_FN)


def prepare_mfcc_double_delta(input):
    """Return MFCC features of ``input`` with deltas and double deltas appended."""
    return _prepare_double_delta(input, MFCC_FN)
|
src/metrics.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from scipy.interpolate import interp1d
|
5 |
+
from scipy.optimize import brentq
|
6 |
+
from sklearn.metrics import roc_curve
|
7 |
+
from sklearn.metrics import roc_curve
|
8 |
+
|
9 |
+
|
10 |
+
def calculate_eer(y, y_score) -> Tuple[float, float, np.ndarray, np.ndarray]:
    """Compute the equal error rate from labels and scores.

    The ROC curve is computed on the negated scores. The EER is the false
    positive rate at which ``1 - fpr`` equals the linearly interpolated true
    positive rate, found by root search on ``[0, 1]``.

    Returns:
        ``(thresh, eer, fpr, tpr)``: the threshold interpolated at the EER,
        the EER itself, and the ROC curve arrays.
    """
    fpr, tpr, thresholds = roc_curve(y, -y_score)

    tpr_at = interp1d(fpr, tpr)
    threshold_at = interp1d(fpr, thresholds)

    def miss_minus_false_alarm(rate):
        # Zero exactly where fpr == 1 - tpr.
        return 1.0 - rate - tpr_at(rate)

    eer = brentq(miss_minus_false_alarm, 0.0, 1.0)
    thresh = threshold_at(eer)
    return thresh, eer, fpr, tpr
|
src/trainer.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A generic training wrapper."""
|
2 |
+
from copy import deepcopy
|
3 |
+
import logging
|
4 |
+
from typing import Callable, List, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
|
9 |
+
|
10 |
+
LOGGER = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class Trainer:
    """Container for the hyper-parameters shared by training loops.

    Args:
        epochs: number of training epochs.
        batch_size: batch size used for the data loaders.
        device: device identifier that batches are moved to.
        optimizer_fn: callable that builds the optimizer from the model
            parameters and ``optimizer_kwargs``.
        optimizer_kwargs: keyword arguments for ``optimizer_fn``; defaults to
            ``{"lr": 1e-3}`` when omitted.
        use_scheduler: whether a learning-rate scheduler should be used.
    """

    def __init__(
        self,
        epochs: int = 20,
        batch_size: int = 32,
        device: str = "cpu",
        optimizer_fn: Callable = torch.optim.Adam,
        optimizer_kwargs: Optional[dict] = None,
        use_scheduler: bool = False,
    ) -> None:
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        self.optimizer_fn = optimizer_fn
        # Build the default per instance: a dict literal in the signature
        # would be one object shared (and mutable) across all trainers.
        self.optimizer_kwargs = (
            {"lr": 1e-3} if optimizer_kwargs is None else optimizer_kwargs
        )
        self.epoch_test_losses: List[float] = []
        self.use_scheduler = use_scheduler
|
30 |
+
|
31 |
+
|
32 |
+
def forward_and_loss(model, criterion, batch_x, batch_y, **kwargs):
    """Run ``model`` on ``batch_x`` and score the output against ``batch_y``.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(batch_out, batch_loss)``: the model output and the value of
        ``criterion(batch_out, batch_y)``.
    """
    predictions = model(batch_x)
    return predictions, criterion(predictions, batch_y)
|
36 |
+
|
37 |
+
|
38 |
+
class GDTrainer(Trainer):
    """Gradient-descent trainer for a binary classifier with a single logit."""

    def train(
        self,
        dataset: torch.utils.data.Dataset,
        model: torch.nn.Module,
        test_len: Optional[float] = None,
        test_dataset: Optional[torch.utils.data.Dataset] = None,
    ):
        """Train ``model`` and return it loaded with its best weights.

        When ``test_dataset`` is given it is used for evaluation; otherwise
        ``dataset`` is split randomly and a ``test_len`` fraction of it is
        held out. After every epoch the model is evaluated, and the state
        dict with the highest test accuracy is restored before returning.

        Batches are expected to unpack as ``(batch_x, _, batch_y)``.
        """
        if test_dataset is None:
            test_len = int(len(dataset) * test_len)
            train_len = len(dataset) - test_len
            train_set, eval_set = torch.utils.data.random_split(
                dataset, [train_len, test_len]
            )
        else:
            train_set, eval_set = dataset, test_dataset

        # Both loaders use the same settings.
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=6,
        )
        train_loader = DataLoader(train_set, **loader_options)
        test_loader = DataLoader(eval_set, **loader_options)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = self.optimizer_fn(model.parameters(), **self.optimizer_kwargs)

        best_model = None
        best_acc = 0

        LOGGER.info(f"Starting training for {self.epochs} epochs!")

        if self.use_scheduler:
            # The scheduler is stepped once per batch, so one restart cycle
            # spans two epochs.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer=optimizer,
                T_0=len(train_loader) * 2,
                T_mult=1,
                eta_min=5e-6,
            )
        use_cuda = self.device != "cpu"

        for epoch in range(self.epochs):
            LOGGER.info(f"Epoch num: {epoch}")

            # ---- training pass ----
            running_loss = 0
            num_correct = 0.0
            num_total = 0.0
            model.train()

            for i, (batch_x, _, batch_y) in enumerate(train_loader):
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                # Targets become a float column matching the logit output.
                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)

                batch_out, batch_loss = forward_and_loss(
                    model, criterion, batch_x, batch_y, use_cuda=use_cuda
                )
                # Adding 0.5 before the int cast thresholds the probability at 0.5.
                batch_pred = (torch.sigmoid(batch_out) + 0.5).int()
                num_correct += (batch_pred == batch_y.int()).sum(dim=0).item()

                running_loss += batch_loss.item() * batch_size

                if i % 100 == 0:
                    LOGGER.info(
                        f"[{epoch:04d}][{i:05d}]: {running_loss / num_total} {num_correct/num_total*100}"
                    )

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                if self.use_scheduler:
                    scheduler.step()

            running_loss /= num_total
            train_accuracy = (num_correct / num_total) * 100

            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: train/loss: {running_loss}, train/accuracy: {train_accuracy}"
            )

            # ---- evaluation pass ----
            test_running_loss = 0.0
            num_correct = 0.0
            num_total = 0.0
            model.eval()
            eer_val = 0  # logged below; never updated in this loop

            for batch_x, _, batch_y in test_loader:
                batch_size = batch_x.size(0)
                num_total += batch_size
                batch_x = batch_x.to(self.device)

                with torch.no_grad():
                    batch_pred = model(batch_x)

                batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device)
                batch_loss = criterion(batch_pred, batch_y)

                test_running_loss += batch_loss.item() * batch_size

                batch_pred = torch.sigmoid(batch_pred)
                batch_pred_label = (batch_pred + 0.5).int()
                num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            # Avoid dividing by zero when the test loader produced no batch.
            if num_total == 0:
                num_total = 1

            test_running_loss /= num_total
            test_acc = 100 * (num_correct / num_total)
            LOGGER.info(
                f"Epoch [{epoch+1}/{self.epochs}]: test/loss: {test_running_loss}, test/accuracy: {test_acc}, test/eer: {eer_val}"
            )

            # Keep a copy of the weights with the best test accuracy so far.
            if best_model is None or test_acc > best_acc:
                best_acc = test_acc
                best_model = deepcopy(model.state_dict())

            LOGGER.info(
                f"[{epoch:04d}]: {running_loss} - train acc: {train_accuracy} - test_acc: {test_acc}"
            )

        model.load_state_dict(best_model)
        return model
|