MatchPrePrintArticles / src /dataset /NegativeSampler.py
tmencatt's picture
app
b5cf002
import random
from typing import List, Dict, Any, Union, Set, Callable
import copy
import pandas as pd
import numpy as np
import nltk
from nltk.corpus import words
nltk.download("words")
from src.dataset.GoodDataset import *
def copy_column_value(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    source_col: str,
    target_col: str,
    source_transform: Callable[[Any], Any] = lambda x: x
) -> List[pd.DataFrame]:
    """
    Copy the first-row value of `source_col` in `df1` into the first row of
    `target_col` in a deep copy of `df2`.

    Neither input DataFrame is modified: `df1` is only read and `df2` is
    deep-copied before the write.

    Args:
        df1 (pd.DataFrame): The source DataFrame containing the value to copy.
        df2 (pd.DataFrame): The target DataFrame whose copy receives the value.
        source_col (str): Column in `df1` from which the value is read.
        target_col (str): Column in the copy of `df2` to which the value is
            written (created by pandas if it does not exist yet).
        source_transform (Callable[[Any], Any], optional): Function applied to
            the value before it is written. Defaults to the identity.

    Returns:
        List[pd.DataFrame]: `[df1, df2_copy]` — the original `df1` and the
        modified copy of `df2`.

    Raises:
        IndexError: If `df1` has no rows.
    """
    # Work on a deep copy so the caller's `df2` stays unchanged.
    df2_copy = copy.deepcopy(df2)
    # Read the value from the first row (positionally) of `df1`.
    value_to_copy = df1.iloc[0][source_col]
    # Address the first row by its real index label. Writing to the literal
    # label 0 would silently append a new row whenever `df2` is not indexed
    # from 0 (e.g. a row sliced out of a larger frame). An empty frame keeps
    # the previous behavior of creating a row labelled 0.
    first_label = df2_copy.index[0] if len(df2_copy.index) else 0
    df2_copy.at[first_label, target_col] = source_transform(value_to_copy)
    return [df1, df2_copy]
def keep_on_condition(
    dataset: List[List[pd.DataFrame]],
    column_to_check: str,
    indices_to_ignore: Union[List[int], Set[int], int],
    function_to_compare: Callable[[Any], bool]
) -> List[List[pd.DataFrame]]:
    """
    Keep the (preprint, article) pairs whose article satisfies a predicate.

    For each pair, the first-row value of the article column
    ``f"article_{column_to_check}"`` is passed to `function_to_compare`, and the
    pair is kept when the predicate returns a truthy value. The following pairs
    are dropped without evaluating the predicate:
    - pairs at positions in `indices_to_ignore`;
    - pairs whose article lacks the column;
    - pairs whose article has no rows.

    Args:
        dataset (List[List[pd.DataFrame]]): The dataset to filter, organized as
            a list of `[preprint, article]` DataFrame pairs.
        column_to_check (str): Column name without the `article_` prefix; the
            prefixed column is read from each article DataFrame.
        indices_to_ignore (Union[List[int], Set[int], int]): Position(s) in
            `dataset` to exclude. A single integer (Python or NumPy) or any
            iterable of integers.
        function_to_compare (Callable[[Any], bool]): Predicate applied to the
            article's column value.

    Returns:
        List[List[pd.DataFrame]]: The kept pairs as `[preprint, article]` lists,
        in their original order.
    """
    # Normalize `indices_to_ignore` to a set for O(1) membership tests.
    if isinstance(indices_to_ignore, (int, np.integer)):
        ignored = {indices_to_ignore}
    else:
        ignored = set(indices_to_ignore)

    column = f"article_{column_to_check}"
    kept = []
    for i, (preprint, article) in enumerate(dataset):
        # Run the cheap exclusions first so the predicate (which may have side
        # effects) is never evaluated on rows that would be discarded anyway,
        # and so `.iloc[0]` is never attempted on an empty frame.
        if i in ignored or column not in article.columns or article.empty:
            continue
        if function_to_compare(article[column].iloc[0]):
            kept.append([preprint, article])
    return kept
class NegativeSampler:
    """
    Build negative (non-matching) preprint-article pairs from the positive
    pairs held by a dataset.

    Every strategy returns a list of ``[preprint_df, article_df]`` pairs. The
    helpers only read/write the first row of each DataFrame; preprint columns
    are prefixed ``prpnt_`` and article columns ``article_``.
    """

    def __init__(self, dataset: AugmentedDataset):
        """
        Initialize the sampler.

        Args:
            dataset (AugmentedDataset): Dataset whose `positive_samples` (a list
                of `[preprint_df, article_df]` pairs) form the sampling pool.
                `create_negative_samples` writes its output back to
                `dataset.negative_samples`.
        """
        self.dataset = dataset
        self.positive_samples = dataset.positive_samples

    def sample_random(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Pair a preprint with randomly chosen, non-matching articles.

        Args:
            preprint_index (int): Position of the preprint in the pool. Its own
                pair is excluded from the draw. If negative, nothing is excluded
                and the sampled pairs are returned as-is (used by the
                overlap/topic strategies, which pass a pre-filtered pool).
            factor_max (int): Maximum number of negatives to draw.
            random_state (int): When >= 0, seeds NumPy's global RNG before
                drawing; otherwise the current global state is used.
            custom_samples (List[List[pd.DataFrame]], optional): Pool to sample
                from instead of `self.positive_samples`.

        Returns:
            List[List[pd.DataFrame]]: Up to `factor_max` pairs drawn without
            replacement. For `preprint_index >= 0` each pair is
            `[preprint, non_matching_article]`.

        Raises:
            AssertionError: If the pool has no candidate besides the preprint.
        """
        if random_state >= 0:
            # NOTE(review): this reseeds NumPy's *global* RNG on every call, and
            # `create_negative_samples` passes the same seed for every preprint,
            # so calls over same-sized pools draw the same positions.
            np.random.seed(random_state)
        pool = custom_samples if custom_samples is not None else self.positive_samples

        other_indices = np.array([j for j in range(len(pool)) if j != preprint_index])
        # Size the draw from the candidates actually available: excluding the
        # preprint's own pair leaves len(pool) - 1 candidates, and drawing
        # len(pool) of them without replacement would raise ValueError.
        factor = min(len(other_indices), factor_max)
        assert factor >= 1, "Dataset doesn't contain enough samples"

        sampled_indices = np.random.choice(other_indices, size=factor, replace=False)
        sampled_rows = [pool[j] for j in sampled_indices]
        if preprint_index < 0:
            return sampled_rows

        # Pair the original preprint with each sampled (non-matching) article.
        preprint, _ = pool[preprint_index]
        return [
            [preprint, non_matching_article]
            for _, non_matching_article in sampled_rows
        ]

    def fuzz_title(
        self,
        fuzz_count: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Build negatives whose article title is a corrupted copy of the preprint title.

        For every pair, the preprint's `prpnt_basic_title` is copied into the
        article's `article_basic_title` (on a copy of the article frame), with
        random word positions overwritten by random words from the NLTK `words`
        corpus. All other article columns are left unchanged.

        Args:
            fuzz_count (int): Number of replacements per title. -1 means half
                the number of spaces in the title (integer division). Positions
                are drawn with replacement, so a word may be replaced twice.
            custom_samples (List[List[pd.DataFrame]], optional): Pairs to fuzz
                instead of `self.positive_samples`. An explicit empty list
                yields an empty result.

        Returns:
            List[List[pd.DataFrame]]: One `[preprint, fuzzed_article_copy]` pair
            per input pair.
        """
        pool = custom_samples if custom_samples is not None else self.positive_samples
        if not pool:
            return []
        # Load the corpus once per call instead of once per title.
        word_list = words.words()

        def replace_with_random_words(text: str) -> str:
            """Replace `fuzz_count` random word positions of `text` with random corpus words."""
            # Leave non-string titles (e.g. missing values) untouched.
            if not isinstance(text, str):
                return text
            count = text.count(" ") // 2 if fuzz_count == -1 else fuzz_count
            text_words = text.split()
            # A blank title has no position to replace; randint(0, -1) would raise.
            if not text_words:
                return ""
            for _ in range(count):
                index_to_replace = random.randint(0, len(text_words) - 1)
                text_words[index_to_replace] = random.choice(word_list)
            return " ".join(text_words)

        return [
            copy_column_value(preprint, article, "prpnt_basic_title", "article_basic_title", replace_with_random_words)
            for preprint, article in pool
        ]

    def sample_authors_overlap_random(
        self,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Build negatives by overwriting each article's authors with the preprint's.

        For every pair, the preprint's `prpnt_authors_id` value is copied into
        the article's `article_authors_id` on a copy of the article frame.

        Args:
            custom_samples (List[List[pd.DataFrame]], optional): Pairs to modify
                instead of `self.positive_samples`. An explicit empty list
                yields an empty result.

        Returns:
            List[List[pd.DataFrame]]: One `[preprint, article_copy]` pair per
            input pair.
        """
        pool = custom_samples if custom_samples is not None else self.positive_samples
        return [
            copy_column_value(preprint, article, "prpnt_authors_id", "article_authors_id")
            for preprint, article in pool
        ]

    def sample_authors_overlap(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1,
        authors_to_consider: int = 1,
        overlapping_authors: int = 1
    ) -> List[List[pd.DataFrame]]:
        """
        Pair a preprint with other articles that share some of its authors.

        A candidate is any positive pair, other than the preprint's own, whose
        article satisfies both conditions:
        - it has an `article_authors_id` column;
        - its first `authors_to_consider` author IDs contain at least
          `overlapping_authors` of the preprint's author IDs.

        Up to `factor_max` candidates are then drawn via `sample_random`.

        Args:
            preprint_index (int): Position of the preprint in `self.positive_samples`.
            factor_max (int): Maximum number of negatives to return.
            random_state (int): Seed forwarded to `sample_random` (>= 0 to seed).
            authors_to_consider (int): How many leading article authors are
                compared (-1 for all).
            overlapping_authors (int): Minimum number of shared author IDs.

        Returns:
            List[List[pd.DataFrame]]: `[preprint, article]` pairs; empty when no
            other article meets the overlap criterion.
        """
        suffix = "authors_id"

        def extract_authors(authors_str: str, authors_to_keep: int = -1) -> list:
            """Return the first `authors_to_keep` non-empty "$@$"-separated IDs (all if -1); non-strings give []."""
            if not isinstance(authors_str, str):
                return []
            # Drop empty tokens (e.g. from a trailing delimiter) so an empty
            # string can never count as a shared author.
            authors_list = [author for author in authors_str.split("$@$") if author]
            if authors_to_keep == -1:
                return authors_list
            return authors_list[:authors_to_keep]

        positive_preprint, _ = self.positive_samples[preprint_index]
        preprint_authors = set(extract_authors(positive_preprint[f"prpnt_{suffix}"].iloc[0]))

        def confirm_overlap(article_authors) -> bool:
            """True when enough of the article's leading authors also wrote the preprint."""
            shared = preprint_authors & set(extract_authors(article_authors, authors_to_consider))
            return len(shared) >= overlapping_authors

        # Keep pairs whose article overlaps in authors, excluding the preprint's
        # own pair so it cannot be sampled as a negative.
        custom_samples = keep_on_condition(
            self.positive_samples,
            suffix,
            preprint_index,
            confirm_overlap
        )
        # No eligible article: produce no negatives for this preprint instead
        # of failing the assertion in `sample_random`.
        if not custom_samples:
            return []
        # preprint_index = -1 makes `sample_random` exclude nothing from the
        # already-filtered pool (indices j >= 0 never equal -1).
        return [
            [positive_preprint, article]
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def sample_similar_topic(
        self,
        preprint_index: int,
        factor_max: int = 4,
        random_state: int = -1
    ) -> List[List[pd.DataFrame]]:
        """
        Pair a preprint with other articles that share its article's primary topic field.

        The topic is read from `article_classification_primary_topic_field` of
        the preprint's matching article and compared with `==`. A missing (NaN)
        topic therefore matches nothing.

        Args:
            preprint_index (int): Position of the preprint in `self.positive_samples`.
            factor_max (int): Maximum number of negatives to return.
            random_state (int): Seed forwarded to `sample_random` (>= 0 to seed).

        Returns:
            List[List[pd.DataFrame]]: `[preprint, article]` pairs; empty when no
            other article has the same topic.
        """
        suffix = "classification_primary_topic_field"
        positive_preprint, positive_article = self.positive_samples[preprint_index]
        # Read the target topic once rather than once per compared row.
        target_topic = positive_article[f"article_{suffix}"].iloc[0]

        # Keep pairs whose article has the same topic, excluding the preprint's
        # own pair so it cannot be sampled as a negative.
        custom_samples = keep_on_condition(
            self.positive_samples,
            suffix,
            preprint_index,
            lambda topic: topic == target_topic
        )
        # No eligible article: produce no negatives for this preprint instead
        # of failing the assertion in `sample_random`.
        if not custom_samples:
            return []
        # preprint_index = -1 makes `sample_random` exclude nothing from the
        # already-filtered pool (indices j >= 0 never equal -1).
        return [
            [positive_preprint, article]
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def create_negative_samples(self, config):
        """
        Generate negatives for every positive pair and store them on the dataset.

        The strategy for each preprint is chosen from attributes of `config`:
        - `overlap_auth` without `overlap_topic` -> `sample_authors_overlap`
        - `overlap_topic` without `overlap_auth` -> `sample_similar_topic`
        - otherwise, if `random` -> `sample_random`
        - otherwise the preprint gets no negatives.

        When `fuzz_title` is set, the chosen negatives are passed through
        `self.fuzz_title`. When `replace_auth` is set, they are then passed
        through `self.sample_authors_overlap_random`.

        The options `factor_max`, `seed`, `authors_to_consider` and
        `overlapping_authors` are forwarded to the samplers.

        The concatenated result is assigned to `self.dataset.negative_samples`;
        nothing is returned.
        """
        negative_samples = []
        # NOTE(review): `tqdm` is not imported explicitly in this file; it
        # presumably comes from `from src.dataset.GoodDataset import *` — confirm.
        for preprint_index in tqdm(range(len(self.positive_samples)), desc="Negative Sampling"):
            if config.overlap_auth and not config.overlap_topic:
                negatives = self.sample_authors_overlap(
                    preprint_index,
                    factor_max=config.factor_max,
                    random_state=config.seed,
                    authors_to_consider=config.authors_to_consider,
                    overlapping_authors=config.overlapping_authors
                )
            elif config.overlap_topic and not config.overlap_auth:
                negatives = self.sample_similar_topic(
                    preprint_index,
                    factor_max=config.factor_max,
                    random_state=config.seed
                )
            elif config.random:
                negatives = self.sample_random(
                    preprint_index,
                    factor_max=config.factor_max,
                    random_state=config.seed
                )
            else:
                continue

            if config.fuzz_title:
                negatives = self.fuzz_title(custom_samples=negatives)
            if config.replace_auth:
                negatives = self.sample_authors_overlap_random(negatives)
            negative_samples.extend(negatives)

        self.dataset.negative_samples = negative_samples