# NOTE(review): removed non-code page artifacts ("Spaces:" / "Sleeping" status
# lines) that were not valid Python and would break import of this module.
import random | |
from typing import List, Dict, Any, Union, Set, Callable | |
import copy | |
import pandas as pd | |
import numpy as np | |
import nltk | |
from nltk.corpus import words | |
nltk.download("words") | |
from src.dataset.GoodDataset import * | |
def copy_column_value(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    source_col: str,
    target_col: str,
    source_transform: Callable[[Any], Any] = lambda x: x
) -> List[pd.DataFrame]:
    """
    Copy the first-row value of `source_col` in `df1` into the first row of
    `target_col` in a deep copy of `df2`, leaving both originals unmodified.

    Args:
        df1 (pd.DataFrame): Source DataFrame containing the value to copy.
        df2 (pd.DataFrame): Target DataFrame receiving the value (a copy is
            modified; `df2` itself is never mutated).
        source_col (str): Column in `df1` from which the value is read.
        target_col (str): Column in `df2` into which the value is written.
        source_transform (Callable[[Any], Any]): Optional transformation
            applied to the value before writing. Defaults to identity.

    Returns:
        List[pd.DataFrame]: `[df1, modified_copy_of_df2]`.
    """
    # Deep-copy df2 so the caller's DataFrame is never mutated.
    df2_copy = copy.deepcopy(df2)
    # Read positionally from the first row (single-step indexing; does not
    # depend on the DataFrame's index labels, unlike chained .iloc[0][col]).
    value_to_copy = df1[source_col].iloc[0]
    # Write positionally into the first row of the copy. The previous
    # `.at[0, target_col]` targeted the index *label* 0 and silently appended
    # a brand-new row whenever the index did not contain that label.
    df2_copy.iloc[0, df2_copy.columns.get_loc(target_col)] = source_transform(value_to_copy)
    return [df1, df2_copy]
def keep_on_condition(
    dataset: List[List[pd.DataFrame]],
    column_to_check: str,
    indices_to_ignore: Union[List[int], Set[int], int],
    function_to_compare: Callable[[Any], bool]
) -> List[List[pd.DataFrame]]:
    """
    Filter a dataset of (preprint, article) pairs on an article column value.

    A pair is kept when its position is not in `indices_to_ignore`, the
    article DataFrame has the column `article_<column_to_check>`, and that
    column's first-row value satisfies `function_to_compare`.

    Args:
        dataset (List[List[pd.DataFrame]]): Pairs of (preprint, article) DataFrames.
        column_to_check (str): Suffix of the article column to test; the full
            column name checked is `article_<column_to_check>`.
        indices_to_ignore (Union[List[int], Set[int], int]): Position(s) in
            `dataset` to exclude regardless of the predicate.
        function_to_compare (Callable[[Any], bool]): Predicate applied to the
            first-row value of the checked column.

    Returns:
        List[List[pd.DataFrame]]: The pairs that pass the filter.
    """
    # Normalize `indices_to_ignore` to a set for O(1) membership tests.
    if isinstance(indices_to_ignore, int):
        indices_to_ignore = {indices_to_ignore}
    else:
        indices_to_ignore = set(indices_to_ignore)
    # Check the index exclusion first so the (possibly side-effecting)
    # predicate is never invoked for ignored pairs.
    return [
        [preprint, article]
        for i, (preprint, article) in enumerate(dataset)
        if (
            i not in indices_to_ignore and
            f"article_{column_to_check}" in article.columns and
            function_to_compare(article[f"article_{column_to_check}"].iloc[0])
        )
    ]
class NegativeSampler:
    """
    Builds negative (non-matching) preprint-article pairs from a dataset of
    positive (matching) pairs.

    Strategies: purely random articles, articles with author overlap, articles
    on the same primary topic, fuzzed titles, and replaced author id lists.
    """

    def __init__(self, dataset: "AugmentedDataset"):
        """
        Initialize the sampler with a dataset of preprint-article pairs.

        Args:
            dataset: An AugmentedDataset whose `positive_samples` attribute
                holds the matching (preprint, article) DataFrame pairs.
        """
        self.dataset = dataset
        # Cached reference to the positive pairs; every sampler reads from this.
        self.positive_samples = dataset.positive_samples

    def sample_random(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Pair a preprint with randomly chosen non-matching articles.

        Args:
            preprint_index: Position of the preprint's own pair in the pool.
                A negative value excludes nothing and returns the sampled
                pairs themselves instead of re-pairing them with a preprint.
            factor_max: Maximum number of negative samples to draw.
            random_state: Seed for numpy's RNG; ignored when negative.
            custom_samples: Optional pool to draw from instead of
                `self.positive_samples`.

        Returns:
            List of [preprint, non_matching_article] pairs (or the raw sampled
            pairs when `preprint_index` is negative).
        """
        if random_state >= 0:
            np.random.seed(random_state)
        positive_samples = custom_samples if custom_samples is not None else self.positive_samples
        # Candidate pool: every pair except the preprint's own. A negative
        # preprint_index excludes nothing, since every j >= 0 differs from it.
        other_indices = np.array([j for j in range(len(positive_samples)) if j != preprint_index])
        # Cap the draw at the pool size. (Capping at len(positive_samples), as
        # before, made np.random.choice(..., replace=False) raise whenever the
        # excluded index shrank the pool below the requested size.)
        factor = min(len(other_indices), factor_max)
        assert factor >= 1, "Dataset doesn't contain enough samples"
        sampled_indices = np.random.choice(other_indices, size=factor, replace=False)
        sampled_rows = [positive_samples[j] for j in sampled_indices]
        if preprint_index < 0:
            return sampled_rows
        # Re-pair the selected non-matching articles with the original preprint.
        preprint, _ = positive_samples[preprint_index]
        return [
            [preprint, non_matching_article]
            for _, non_matching_article in sampled_rows
        ]

    def fuzz_title(
        self,
        fuzz_count: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Create negatives by replacing words of each article title with random
        English words from the NLTK corpus.

        Args:
            fuzz_count: Number of words to replace per title; -1 replaces half
                of the words (by space count).
            custom_samples: Optional pairs to fuzz instead of
                `self.positive_samples`.

        Returns:
            List of [preprint, fuzzed_article] pairs; inputs are not mutated.
        """
        # Load the NLTK word list once, not once per title.
        word_list = words.words()

        def replace_with_random_words(text: str, fuzz_count: int = fuzz_count) -> str:
            """Replace `fuzz_count` randomly chosen words of `text` with random corpus words."""
            if fuzz_count == -1:
                fuzz_count = text.count(" ") // 2
            text_words = text.split()
            # Nothing to fuzz in an empty/whitespace-only title; avoids
            # random.randint(0, -1) blowing up.
            if not text_words:
                return text
            for _ in range(fuzz_count):
                index_to_replace = random.randint(0, len(text_words) - 1)
                text_words[index_to_replace] = random.choice(word_list)
            return " ".join(text_words)

        # `is not None` (rather than `or`) so an explicitly empty list yields
        # an empty result instead of silently fuzzing every positive pair.
        samples = custom_samples if custom_samples is not None else self.positive_samples
        return [
            copy_column_value(preprint, article, "prpnt_basic_title", "article_basic_title", replace_with_random_words)
            for preprint, article in samples
        ]

    def sample_authors_overlap_random(
        self,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Create negatives by overwriting each article's author ids with the
        paired preprint's author ids.

        Args:
            custom_samples: Optional pairs to rewrite instead of
                `self.positive_samples`.

        Returns:
            List of [preprint, article_with_preprint_authors] pairs.
        """
        # `is not None` (rather than `or`) so an explicitly empty list yields
        # an empty result instead of falling back to all positives.
        samples = custom_samples if custom_samples is not None else self.positive_samples
        return [
            copy_column_value(preprint, article, "prpnt_authors_id", "article_authors_id")
            for preprint, article in samples
        ]

    def sample_authors_overlap(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1,
        authors_to_consider: int = 1,
        overlapping_authors: int = 1
    ) -> List[List[pd.DataFrame]]:
        """
        Create negatives from articles that share authors with the preprint.

        Args:
            preprint_index: Position of the preprint's pair in `positive_samples`.
            factor_max: Maximum number of negatives to draw.
            random_state: Seed for numpy's RNG; ignored when negative.
            authors_to_consider: How many leading authors of each candidate
                article are compared against the preprint's authors.
            overlapping_authors: Minimum number of shared authors required.

        Returns:
            List of [preprint, overlapping_article] pairs.
        """
        def extract_authors(authors_str: str, authors_to_keep: int = -1) -> list:
            """
            Split a `$@$`-delimited author-id string into a list.

            Args:
                authors_str: Author ids joined by the custom `$@$` delimiter.
                authors_to_keep: Number of leading authors to keep; -1 keeps all.

            Returns:
                list: The (possibly truncated) list of author ids.
            """
            # str.split always yields at least one element, so the old
            # emptiness check (and its ValueError) was unreachable.
            authors_list = authors_str.split("$@$")
            if authors_to_keep == -1:
                authors_to_keep = len(authors_list)
            return authors_list[:authors_to_keep]

        suffix = "authors_id"
        positive_preprint, _ = self.positive_samples[preprint_index]
        preprint_authors = set(extract_authors(positive_preprint[f"prpnt_{suffix}"].iloc[0]))

        def confirm_overlap(article_authors: str) -> bool:
            """True when enough of the article's leading authors appear among the preprint's."""
            article_author_set = set(extract_authors(article_authors, authors_to_consider))
            # Intersection computed once (previously computed twice, with a
            # leftover debug print of matches to stdout).
            return len(preprint_authors & article_author_set) >= overlapping_authors

        # Candidate pairs whose article overlaps the selected preprint's
        # authors; the preprint's own pair is excluded up front.
        custom_samples = keep_on_condition(
            self.positive_samples,
            suffix,
            preprint_index,
            confirm_overlap
        )
        # preprint_index=-1 below: sample_random excludes nothing further,
        # because every candidate index j >= 0 satisfies j != -1.
        return [
            [positive_preprint, article]
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def sample_similar_topic(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1
    ) -> List[List[pd.DataFrame]]:
        """
        Create negatives from articles sharing the preprint's primary topic field.

        Args:
            preprint_index: Position of the preprint's pair in `positive_samples`.
            factor_max: Maximum number of negatives to draw.
            random_state: Seed for numpy's RNG; ignored when negative.

        Returns:
            List of [preprint, same_topic_article] pairs.
        """
        suffix = "classification_primary_topic_field"
        positive_preprint, positive_article = self.positive_samples[preprint_index]
        # Hoist the target topic out of the predicate; it is loop-invariant.
        target_topic = positive_article[f"article_{suffix}"].iloc[0]
        # Candidate pairs whose article carries the same topic field; the
        # selected preprint's own pair is excluded up front.
        custom_samples = keep_on_condition(
            self.positive_samples,
            suffix,
            preprint_index,
            lambda x: x == target_topic
        )
        # preprint_index=-1 below: sample_random excludes nothing further,
        # because every candidate index j >= 0 satisfies j != -1.
        return [
            [positive_preprint, article]
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def create_negative_samples(self, config):
        """
        Populate `self.dataset.negative_samples` according to `config`.

        Args:
            config: Object exposing the flags/parameters read below:
                `overlap_auth`, `overlap_topic`, `random`, `fuzz_title`,
                `replace_auth`, `factor_max`, `seed`, `authors_to_consider`,
                `overlapping_authors`.
        """
        negative_samples = []
        for preprint_index in tqdm(range(len(self.positive_samples)), desc="Negative Sampling"):
            # Pick exactly one base strategy; skip the preprint when both
            # overlap flags are set (ambiguous) or nothing is requested.
            if config.overlap_auth and not config.overlap_topic:
                negatives = self.sample_authors_overlap(
                    preprint_index, factor_max=config.factor_max,
                    random_state=config.seed,
                    authors_to_consider=config.authors_to_consider,
                    overlapping_authors=config.overlapping_authors
                )
            elif config.overlap_topic and not config.overlap_auth:
                negatives = self.sample_similar_topic(preprint_index, factor_max=config.factor_max, random_state=config.seed)
            elif config.random:
                negatives = self.sample_random(preprint_index, factor_max=config.factor_max, random_state=config.seed)
            else:
                continue
            # Optional post-processing applied to the freshly created negatives.
            if config.fuzz_title:
                negatives = self.fuzz_title(custom_samples=negatives)
            if config.replace_auth:
                negatives = self.sample_authors_overlap_random(custom_samples=negatives)
            negative_samples.extend(negatives)
        self.dataset.negative_samples = negative_samples