# streamlit_demo/src/datasets/detection_dataset.py
# ldhldh's picture
# Upload 28 files
# 2c0f55c verified
import logging
from typing import List, Optional
import pandas as pd
from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
from src.datasets.wavefake_dataset import WaveFakeDataset
from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal
from src.datasets.MLAADv3_dataset import MLAADv3
from src.datasets.MAILABS_dataset import MAILABS
from src.datasets.aihub_dataset import AIHUB
from src.datasets.KoAAD_dataset import KoAAD
# Logger used by this module. NOTE: ``logging.getLogger()`` called without a
# name returns the *root* logger, so records emitted here are not tagged with
# this module's name.
LOGGER = logging.getLogger()
class DetectionDataset(SimpleAudioFakeDataset):
    """Concatenation of several audio datasets into a single sample table.

    One sub-dataset is built for every path argument that is not ``None``;
    their ``samples`` frames are concatenated into ``self.samples``. The
    combined frame is expected to have a ``label`` column whose values are
    ``"bonafide"`` and ``"spoof"`` (the resampling and filtering helpers
    below group on exactly these values).
    """

    def __init__(
        self,
        asvspoof_path: Optional[str] = None,
        wavefake_path: Optional[str] = None,
        fakeavceleb_path: Optional[str] = None,
        asvspoof2019_path: Optional[str] = None,
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        """Build the combined dataset.

        Args:
            asvspoof_path, wavefake_path, fakeavceleb_path, asvspoof2019_path,
            MLAADv3_path, MAILABS_path, AIHUB_path, KoAAD_path: Location of
                the respective dataset; ``None`` skips that dataset.
            subset: Subset name forwarded to the base class and to every
                sub-dataset.
            transform: Forwarded unchanged to the base class.
            oversample: If true, balance the labels by duplicating bonafide
                rows (takes precedence over ``undersample``).
            undersample: If true (and ``oversample`` is false), balance the
                labels by sampling spoof rows down to the bonafide count.
            return_label: Forwarded unchanged to the base class.
            reduced_number: If truthy, keep at most this many rows, sampled
                with a fixed seed.
            return_meta: Forwarded unchanged to the base class.

        Raises:
            ValueError: If every dataset path is ``None``.
            KeyError: If resampling is requested and one of the two labels is
                absent from the combined samples.
            NotImplementedError: If resampling is requested and bonafide rows
                outnumber spoof rows.
        """
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )
        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        if not datasets:
            # pd.concat([]) would raise a ValueError as well, but with a
            # message that says nothing about the missing paths.
            raise ValueError(
                "DetectionDataset needs at least one dataset path; "
                "all paths were None."
            )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info("Using reduced number of samples - %s!", reduced_number)
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,  # fixed seed keeps the reduced subset reproducible
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
    ) -> List[SimpleAudioFakeDataset]:
        """Instantiate one sub-dataset for every path that is not ``None``.

        The last four parameters previously had ``Optional[str]`` as their
        *default value* (``=`` instead of ``:``), which made an omitted
        argument look like a real path; they now default to ``None``.

        Returns:
            The created datasets, in the fixed order of the checks below.
        """
        datasets = []

        if asvspoof_path is not None:
            datasets.append(DeepFakeASVSpoofDataset(asvspoof_path, subset=subset))

        if wavefake_path is not None:
            datasets.append(WaveFakeDataset(wavefake_path, subset=subset))

        if fakeavceleb_path is not None:
            datasets.append(FakeAVCelebDataset(fakeavceleb_path, subset=subset))

        if asvspoof2019_path is not None:
            # This dataset takes the subset under a different keyword.
            datasets.append(
                ASVSpoof2019DatasetOriginal(asvspoof2019_path, fold_subset=subset)
            )

        if MLAADv3_path is not None:
            datasets.append(MLAADv3(MLAADv3_path, subset=subset))

        if MAILABS_path is not None:
            datasets.append(MAILABS(MAILABS_path, subset=subset))

        if AIHUB_path is not None:
            datasets.append(AIHUB(AIHUB_path, subset=subset))

        if KoAAD_path is not None:
            datasets.append(KoAAD(KoAAD_path, subset=subset))

        return datasets

    def _group_by_label(self):
        """Return ``self.samples`` grouped by the ``label`` column.

        A scalar key is used on purpose: grouping by the length-1 list
        ``["label"]`` makes newer pandas versions expect tuple group names in
        ``get_group`` / ``groups``, whereas the callers use plain strings.
        """
        return self.samples.groupby("label")

    def oversample_dataset(self):
        """Duplicate random bonafide rows until both labels are equally frequent.

        Raises:
            KeyError: If either label is missing from ``self.samples``.
            NotImplementedError: If bonafide rows outnumber spoof rows.
        """
        grouped = self._group_by_label()
        bona_length = len(grouped.groups["bonafide"])
        spoof_length = len(grouped.groups["spoof"])
        diff_length = spoof_length - bona_length

        if diff_length < 0:
            raise NotImplementedError(
                "Oversampling is only implemented for the case where spoof "
                "samples outnumber bonafide samples."
            )

        if diff_length > 0:
            # Sampling with replacement: the deficit may exceed the number of
            # available bonafide rows.
            bonafide = grouped.get_group("bonafide").sample(diff_length, replace=True)
            self.samples = pd.concat([self.samples, bonafide], ignore_index=True)

    def undersample_dataset(self):
        """Sample spoof rows down to the number of bonafide rows.

        Raises:
            KeyError: If either label is missing from ``self.samples``.
            NotImplementedError: If bonafide rows outnumber spoof rows.
        """
        grouped = self._group_by_label()
        bona_length = len(grouped.groups["bonafide"])
        spoof_length = len(grouped.groups["spoof"])

        if spoof_length < bona_length:
            raise NotImplementedError(
                "Undersampling is only implemented for the case where spoof "
                "samples outnumber bonafide samples."
            )

        if spoof_length > bona_length:
            # NOTE(review): replace=True can pick the same spoof row more than
            # once even though enough distinct rows exist - kept as in the
            # original; confirm whether replace=False was intended.
            spoofs = grouped.get_group("spoof").sample(bona_length, replace=True)
            self.samples = pd.concat(
                [grouped.get_group("bonafide"), spoofs], ignore_index=True
            )

    def get_bonafide_only(self):
        """Restrict ``self.samples`` to bonafide rows (in place) and return them.

        Raises:
            KeyError: If there are no bonafide rows.
        """
        self.samples = self._group_by_label().get_group("bonafide")
        return self.samples

    def get_spoof_only(self):
        """Restrict ``self.samples`` to spoof rows (in place) and return them.

        Raises:
            KeyError: If there are no spoof rows.
        """
        self.samples = self._group_by_label().get_group("spoof")
        return self.samples