import pandas as pd
from src.utils.io_utils import PROJECT_ROOT
from src.dataset.GoodDataset import AugmentedDataset
from src.dataset.NegativeSampler import NegativeSampler
from src.utils.struct_utils import *
import os

class Config:
    """Settings object for negative sampling.

    An instance is passed to ``NegativeSampler.create_negative_samples`` by
    ``negative_sampler`` in this module. Every setting is a plain class
    attribute, so instances share these defaults unless a caller overrides
    them on the instance.
    """

    # --- File locations, resolved relative to PROJECT_ROOT ---
    # NOTE(review): ``negative_sampler`` loads a CSV directly; confirm whether
    # NegativeSampler actually reads ``input``/``output``.
    input = os.path.join(PROJECT_ROOT, "data/positive_samples.pkl")
    output = os.path.join(PROJECT_ROOT, "data/negative_samples.pkl")

    # Seed value; presumably used for reproducible sampling (TODO: confirm).
    seed = 42

    # --- Boolean strategy switches ---
    # Meanings are inferred from the names only; verify against NegativeSampler.
    random = True
    fuzz_title = True
    replace_auth = True
    overlap_auth = False
    overlap_topic = False

    # --- Integer tuning knobs (semantics defined by NegativeSampler) ---
    factor_max = 4
    authors_to_consider = 1
    overlapping_authors = 1
    fuzz_count = 1

def negative_sampler(optional_path=None, factor=None, type_or_difficulty=None) -> pd.DataFrame:
    """Load a CSV, generate negative samples from it, and return them as a DataFrame.

    Args:
        optional_path: Path to the input CSV. Any falsy value (``None``, ``""``)
            selects ``<PROJECT_ROOT>/data/crossref-preprint-article-relationships-Aug-2023.csv``.
        factor: Currently NOT used. Sampling is driven entirely by the
            defaults in ``Config``. A ``UserWarning`` is emitted if this is
            supplied, so callers are not misled into thinking it takes effect.
        type_or_difficulty: Currently NOT used; handled the same way as ``factor``.

    Returns:
        The result of ``custom_struct_to_df`` applied to
        ``dataset.negative_samples`` after ``create_negative_samples`` has run.
    """
    import warnings

    # These parameters are part of the public signature but are not wired into
    # the sampling configuration. Warn instead of silently ignoring them.
    ignored = [
        name
        for name, value in (("factor", factor), ("type_or_difficulty", type_or_difficulty))
        if value is not None
    ]
    if ignored:
        warnings.warn(
            f"negative_sampler: {', '.join(ignored)} currently has no effect; "
            "sampling is controlled by Config.",
            UserWarning,
            stacklevel=2,
        )

    datapath = optional_path if optional_path else f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv"
    dataset = AugmentedDataset()
    dataset.load_csv(datapath)

    sampler = NegativeSampler(dataset)
    config = Config()
    # Presumably populates dataset.negative_samples in place, since that
    # attribute is read below (confirm in NegativeSampler).
    sampler.create_negative_samples(config)

    return custom_struct_to_df(dataset.negative_samples)

def positive_sampler(optional_path=None, size=10, random=True, seed=42, full=False):
    """Fetch positive samples from a CSV and return them as a DataFrame.

    Args:
        optional_path: Path to the CSV. Any falsy value selects
            ``<PROJECT_ROOT>/data/crossref-preprint-article-relationships-Aug-2023.csv``.
        size: Forwarded as ``num_samples`` to ``fetch_positive_samples_parallel``.
        random: Forwarded unchanged to ``fetch_positive_samples_parallel``.
        seed: Forwarded unchanged to ``fetch_positive_samples_parallel``.
        full: Forwarded unchanged to ``fetch_positive_samples_parallel``.

    Returns:
        The result of ``custom_struct_to_df`` applied to
        ``dataset.positive_samples`` after fetching.
    """
    source_csv = optional_path or f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv"
    augmented = AugmentedDataset(source_csv)

    fetch_kwargs = dict(num_samples=size, random=random, seed=seed, full=full)
    augmented.fetch_positive_samples_parallel(**fetch_kwargs)

    return custom_struct_to_df(augmented.positive_samples)