from src.dataset.GoodDataset import *
from src.dataset.NegativeSampler import *
import argparse
import os


def _write_csv(df, path):
    """Write *df* to *path* as CSV without the index, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:
        # DataFrame.to_csv does not create missing directories; do it up front.
        os.makedirs(parent, exist_ok=True)
    df.to_csv(path, index=False)


def main(config):
    """
    Load positive samples, generate negative samples, and export both as CSV files.

    Args:
        config: Namespace object containing the script arguments. ``config.input``
            is read here; the whole namespace is also forwarded to
            ``NegativeSampler.create_negative_samples``.

    Side effects:
        Prints a preview of the negative samples, writes ``./data/pos.csv`` and
        ``./data/neg.csv`` (relative to the current working directory, created if
        missing), then prints the positive and negative sample counts.
    """
    dataset = AugmentedDataset()
    dataset.load(config.input)

    # NOTE(review): config.output is not read in this function; presumably it is
    # consumed by create_negative_samples via the forwarded config — confirm.
    sampler = NegativeSampler(dataset)
    sampler.create_negative_samples(config)

    # Convert the negative samples once and reuse the frame for preview and export.
    negative_df = custom_struct_to_df(dataset.negative_samples)
    print(negative_df.head())

    _write_csv(custom_struct_to_df(dataset.positive_samples), './data/pos.csv')
    _write_csv(negative_df, './data/neg.csv')

    print(len(dataset.positive_samples))
    print(len(dataset.negative_samples))


if __name__ == "__main__":
    # Parse command-line arguments
    from src.utils.io_utils import PROJECT_ROOT

    parser = argparse.ArgumentParser(description="Generate and save a dataset based on the given configuration.")
    parser.add_argument("-i", "--input", type=str, default=os.path.join(PROJECT_ROOT, "data/positive_samples.pkl"),
                        help="Input file path to load the positive samples.")
    parser.add_argument("-o", "--output", type=str, default=os.path.join(PROJECT_ROOT, "data/negative_samples.pkl"),
                        help="Output file path to save the negative samples.")
    parser.add_argument("-s", "--seed", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("-r", "--random", action='store_true', help="Utilization of `sample_random`")
    parser.add_argument("-f", "--fuzz_title", action='store_true', help="Utilization of `fuzz_title`")
    parser.add_argument("-ra", "--replace_auth", action='store_true',
                        help="Utilization of `sample_authors_overlap_random`")
    parser.add_argument("-oa", "--overlap_auth", action='store_true', help="Utilization of `sample_authors_overlap`")
    parser.add_argument("-ot", "--overlap_topic", action='store_true', help="Utilization of `sample_similar_topic`")
    parser.add_argument("--factor_max", type=int, default=4,
                        help="Maximum number of negative samples to generate per positive sample.")
    parser.add_argument("--authors_to_consider", type=int, default=1,
                        help="Number of authors to consider when overlapping authors.")
    parser.add_argument("--overlapping_authors", type=int, default=1,
                        help="Minimum number of overlapping authors required.")
    parser.add_argument("--fuzz_count", type=int, default=-1,
                        help="Number of words to replace when fuzzing titles.")

    # Parse the arguments and pass to the main function
    config = parser.parse_args()

    # The two overlap strategies are mutually exclusive.
    if config.overlap_auth and config.overlap_topic:
        parser.error("Only one of --overlap_auth and --overlap_topic can be used.")
    # At least one primary sampling strategy must be selected.
    if not (config.overlap_auth or config.overlap_topic or config.random):
        parser.error("At least one of --overlap_auth, --overlap_topic, or --random must be specified.")

    main(config)