Spaces:
Runtime error
Runtime error
import logging | |
from typing import List, Optional | |
import pandas as pd | |
from src.datasets.base_dataset import SimpleAudioFakeDataset | |
from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset | |
from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset | |
from src.datasets.wavefake_dataset import WaveFakeDataset | |
from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal | |
from src.datasets.MLAADv3_dataset import MLAADv3 | |
from src.datasets.MAILABS_dataset import MAILABS | |
from src.datasets.aihub_dataset import AIHUB | |
from src.datasets.KoAAD_dataset import KoAAD | |
# Module-scoped logger: records carry this module's name (so they can be
# filtered per module) and still propagate to whatever handlers are configured
# on the root logger.
LOGGER = logging.getLogger(__name__)
class DetectionDataset(SimpleAudioFakeDataset):
    """Combine several audio-deepfake corpora into one sample table.

    Each corpus whose path is given is loaded through its own dataset class
    and the per-corpus ``samples`` tables are concatenated into
    ``self.samples``. The class-balancing helpers rely on a ``"label"``
    column holding the values ``"bonafide"`` and ``"spoof"``.
    """

    def __init__(
        self,
        asvspoof_path: Optional[str] = None,
        wavefake_path: Optional[str] = None,
        fakeavceleb_path: Optional[str] = None,
        asvspoof2019_path: Optional[str] = None,
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        """Load the requested corpora and optionally balance / shrink them.

        Args:
            asvspoof_path: Root passed to ``DeepFakeASVSpoofDataset``.
            wavefake_path: Root passed to ``WaveFakeDataset``.
            fakeavceleb_path: Root passed to ``FakeAVCelebDataset``.
            asvspoof2019_path: Root passed to ``ASVSpoof2019DatasetOriginal``.
            MLAADv3_path: Root passed to ``MLAADv3``.
            MAILABS_path: Root passed to ``MAILABS``.
            AIHUB_path: Root passed to ``AIHUB``.
            KoAAD_path: Root passed to ``KoAAD``.
            subset: Split name forwarded to the base class and every corpus.
            transform: Forwarded unchanged to the base class.
            oversample: Duplicate bonafide rows until both classes match.
                Takes precedence over ``undersample``.
            undersample: Only consulted when ``oversample`` is false; keep as
                many spoof rows as there are bonafide rows.
            return_label: Forwarded unchanged to the base class.
            reduced_number: When truthy, keep at most this many rows, drawn
                with a fixed seed (``random_state=42``).
            return_meta: Forwarded unchanged to the base class.

        Raises:
            ValueError: If every dataset path is ``None``.
        """
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )

        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        if not datasets:
            # pd.concat([]) would raise a ValueError too, but with an opaque
            # "No objects to concatenate" message.
            raise ValueError(
                "DetectionDataset needs at least one dataset path, "
                "but all of them are None."
            )
        self.samples = pd.concat(
            [ds.samples for ds in datasets], ignore_index=True
        )

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info(
                "Using reduced number of samples - %s!", reduced_number
            )
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
    ) -> List[SimpleAudioFakeDataset]:
        """Instantiate one dataset object per non-``None`` path.

        The last four parameters default to ``None`` so that omitting them
        skips the corresponding corpus.

        Returns:
            The created datasets, in the fixed order ASVspoof (deepfake),
            WaveFake, FakeAVCeleb, ASVspoof2019, MLAADv3, MAILABS, AIHUB,
            KoAAD. Empty when every path is ``None``.
        """
        datasets = []

        if asvspoof_path is not None:
            datasets.append(
                DeepFakeASVSpoofDataset(asvspoof_path, subset=subset)
            )
        if wavefake_path is not None:
            datasets.append(WaveFakeDataset(wavefake_path, subset=subset))
        if fakeavceleb_path is not None:
            datasets.append(
                FakeAVCelebDataset(fakeavceleb_path, subset=subset)
            )
        if asvspoof2019_path is not None:
            # This class names its split argument ``fold_subset``.
            datasets.append(
                ASVSpoof2019DatasetOriginal(
                    asvspoof2019_path, fold_subset=subset
                )
            )
        if MLAADv3_path is not None:
            datasets.append(MLAADv3(MLAADv3_path, subset=subset))
        if MAILABS_path is not None:
            datasets.append(MAILABS(MAILABS_path, subset=subset))
        if AIHUB_path is not None:
            datasets.append(AIHUB(AIHUB_path, subset=subset))
        if KoAAD_path is not None:
            datasets.append(KoAAD(KoAAD_path, subset=subset))

        return datasets

    def _samples_with_label(self, label: str) -> pd.DataFrame:
        """Return the rows of ``self.samples`` whose ``label`` equals *label*.

        A boolean mask is used instead of ``groupby(["label"]).get_group``
        because it yields an empty frame (rather than ``KeyError``) for an
        absent class, and does not depend on how pandas keys groups for a
        one-element grouper list.
        """
        return self.samples[self.samples["label"] == label]

    def oversample_dataset(self) -> None:
        """Append resampled bonafide rows until both classes are equal.

        Rows are drawn with replacement from the existing bonafide rows and
        appended to ``self.samples`` (the index is reset). Nothing happens
        when the classes are already balanced.

        Raises:
            NotImplementedError: If bonafide rows outnumber spoof rows.
            ValueError: If there are spoof rows but no bonafide row to copy.
        """
        bonafide = self._samples_with_label("bonafide")
        spoof_count = int((self.samples["label"] == "spoof").sum())
        deficit = spoof_count - len(bonafide)

        if deficit < 0:
            raise NotImplementedError(
                "Oversampling is only implemented for datasets with at "
                "least as many spoof as bonafide samples."
            )
        if deficit == 0:
            return
        if bonafide.empty:
            raise ValueError(
                "Cannot oversample: the dataset has no bonafide samples."
            )

        extra = bonafide.sample(deficit, replace=True)
        self.samples = pd.concat([self.samples, extra], ignore_index=True)

    def undersample_dataset(self) -> None:
        """Shrink the spoof class to the size of the bonafide class.

        ``self.samples`` becomes all bonafide rows followed by the drawn
        spoof rows (the index is reset). Nothing happens when the classes
        are already balanced.

        Raises:
            NotImplementedError: If bonafide rows outnumber spoof rows.
            ValueError: If there are spoof rows but no bonafide rows.
        """
        bonafide = self._samples_with_label("bonafide")
        spoof = self._samples_with_label("spoof")

        if len(spoof) < len(bonafide):
            raise NotImplementedError(
                "Undersampling is only implemented for datasets with at "
                "least as many spoof as bonafide samples."
            )
        if len(spoof) == len(bonafide):
            return
        if bonafide.empty:
            raise ValueError(
                "Cannot undersample: the dataset has no bonafide samples."
            )

        # NOTE(review): replace=True is kept from the original code, so the
        # same spoof row can be drawn more than once even though enough
        # distinct rows exist. Confirm whether replace=False was intended.
        kept_spoof = spoof.sample(len(bonafide), replace=True)
        self.samples = pd.concat([bonafide, kept_spoof], ignore_index=True)

    def get_bonafide_only(self) -> pd.DataFrame:
        """Restrict ``self.samples`` to bonafide rows and return them.

        This mutates the dataset in place; spoof rows are dropped.
        """
        self.samples = self._samples_with_label("bonafide")
        return self.samples

    def get_spoof_only(self) -> pd.DataFrame:
        """Restrict ``self.samples`` to spoof rows and return them.

        This mutates the dataset in place; bonafide rows are dropped.
        """
        self.samples = self._samples_with_label("spoof")
        return self.samples