import copy
import random
from typing import Any, Callable, Dict, List, Set, Union

import numpy as np
import pandas as pd
from tqdm import tqdm

import nltk
from nltk.corpus import words

# Make sure the NLTK word list used for title fuzzing is available.
nltk.download("words", quiet=True)

from src.dataset.GoodDataset import *
def copy_column_value(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    source_col: str,
    target_col: str,
    source_transform: Callable[[Any], Any] = lambda x: x
) -> List[pd.DataFrame]:
"""
Copies the value from `source_col` in `df1` to `target_col` in `df2`,
while ensuring that the original DataFrames remain unaltered by
working on deep copies.
Args:
df1 (pd.DataFrame): The source DataFrame containing the value to copy.
df2 (pd.DataFrame): The target DataFrame where the value will be copied.
source_col (str): The column name in `df1` from which the value will be sourced.
target_col (str): The column name in `df2` where the value will be written.
Returns:
List[pd.DataFrame]: A list containing the original `df1` and the modified copy of `df2`.
"""
# Create a deepcopy of `df2` to ensure the original DataFrame remains unchanged.
df2_copy = copy.deepcopy(df2)
# Extract the value from the first row of the specified source column in `df1`.
value_to_copy = df1.iloc[0][source_col]
# Write the extracted value to the first row of the specified target column in the copied `df2`.
df2_copy.at[0, target_col] = source_transform(value_to_copy)
return [df1, df2_copy]
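
# Example usage of `copy_column_value` (a minimal sketch with toy single-row DataFrames;
# the column names mirror the preprint/article columns used further down in this module):
#
#     src = pd.DataFrame({"prpnt_basic_title": ["A preprint title"]})
#     dst = pd.DataFrame({"article_basic_title": ["An article title"]})
#     _, dst_modified = copy_column_value(
#         src, dst, "prpnt_basic_title", "article_basic_title", str.upper
#     )
#     # dst_modified.at[0, "article_basic_title"] == "A PREPRINT TITLE"
#     # `dst` itself is untouched because the function works on a deep copy.
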
def keep_on_condition(
dataset: List[List[pd.DataFrame]],
column_to_check: str,
indices_to_ignore: Union[List[int], Set[int], int],
function_to_compare: Callable[[Any], bool]
) -> List[List[pd.DataFrame]]:
"""
Filters a dataset based on a column value and ignores specified indices.
Args:
dataset (List[List[pd.DataFrame]]): The dataset to filter, organized as a list of pairs of DataFrames.
column_to_check (str): The column in the article DataFrame to check for values.
values_to_keep (Union[List[Any], Set[Any], Any]): Values to keep in the filtering process.
indices_to_ignore (Union[List[int], Set[int], int]): Indices to ignore during filtering.
article_transform (Callable[[Any], Any], optional): Transformation function for column values. Defaults to identity.
Returns:
List[List[pd.DataFrame]]: Filtered dataset.
"""
# Normalize `indices_to_ignore` to a set
if isinstance(indices_to_ignore, int):
indices_to_ignore = {indices_to_ignore}
elif isinstance(indices_to_ignore, list):
indices_to_ignore = set(indices_to_ignore)
# Filter dataset
return [
[preprint, article]
for i, (preprint, article) in enumerate(dataset)
if (
f"article_{column_to_check}" in article.columns and
function_to_compare(article[f"article_{column_to_check}"].iloc[0]) and
i not in indices_to_ignore
)
]
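
# Example usage of `keep_on_condition` (a minimal sketch; `pairs` stands for a list of
# preprint/article DataFrame pairs and "Computer Science" is a hypothetical topic value):
#
#     filtered = keep_on_condition(
#         dataset=pairs,
#         column_to_check="classification_primary_topic_field",
#         indices_to_ignore=0,
#         function_to_compare=lambda topic: topic == "Computer Science",
#     )
#     # Keeps every pair (except index 0) whose article has the given topic field.
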
class NegativeSampler:
    def __init__(self, dataset: AugmentedDataset):
        """
        Initializes the NegativeSampler with a dataset of preprint-article pairs.

        :param dataset: AugmentedDataset whose `positive_samples` hold the preprint-article pairs.
        """
        self.dataset = dataset
        self.positive_samples = dataset.positive_samples
def sample_random(
self,
preprint_index: int,
factor_max: int = 4,
random_state: int = -1,
custom_samples: List[List[pd.DataFrame]] = None
) -> List[List[pd.DataFrame]]:
"""
Randomly samples a non-matching article to create the negative sample.
:param preprint: The preprint for which to create a negative sample.
:return: A randomly selected negative sample.
"""
if random_state >= 0:
np.random.seed(random_state)
positive_samples = custom_samples if custom_samples is not None else self.positive_samples
        # One fewer candidate is available when the preprint's own pair is excluded.
        available = len(positive_samples) - (1 if preprint_index >= 0 else 0)
        factor = min(available, factor_max)
        assert factor >= 1, "Dataset doesn't contain enough samples"
# Sample `factor` non-matching articles from the dataset to create the negative samples
other_indices = np.array([j for j in range(len(positive_samples)) if j != preprint_index])
sampled_indices = np.random.choice(other_indices, size=factor, replace=False)
sampled_rows = [positive_samples[j] for j in sampled_indices]
if preprint_index < 0:
return sampled_rows
# Create and return the negative samples using the original preprint and the sampled article
preprint, _ = positive_samples[preprint_index]
return [
[preprint, non_matching_article]
for _, non_matching_article in sampled_rows
]
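
    # Example (a minimal sketch; `sampler` is a NegativeSampler built from an
    # AugmentedDataset with populated `positive_samples`):
    #
    #     negatives = sampler.sample_random(preprint_index=0, factor_max=2, random_state=42)
    #     # Up to two [preprint_0, non-matching article] pairs.
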
def fuzz_title(
self,
fuzz_count: int = -1,
custom_samples: List[List[pd.DataFrame]] = None
) -> List[List[pd.DataFrame]]:
"""
Fuzzes out the title to create the negative sample. Likely changes the abstract and/or authors.
:param preprint: The preprint for which to create a negative sample.
:param fuzz_factor: A threshold for title similarity (0.0 to 1.0).
:return: A negative sample with a fuzzed title.
"""
def replace_with_random_words(text: str, fuzz_count: int = fuzz_count) -> str:
"""
Replaces a specified number of words in the input string with random words
from the NLTK `words` corpus.
Args:
text (str): The input string to fuzz.
fuzz_count (int): The number of words to replace in the string.
Returns:
str: The string with random word replacements.
"""
if fuzz_count == -1:
fuzz_count = text.count(" ") // 2
# Load the list of English words from the NLTK corpus
word_list = words.words()
# Split the input text into a list of words
text_words = text.split()
for _ in range(fuzz_count):
# Randomly pick a word in the text to replace
index_to_replace = random.randint(0, len(text_words) - 1)
# Replace it with a random word from the NLTK corpus
random_word = random.choice(word_list)
text_words[index_to_replace] = random_word
# Join the list back into a string and return
return " ".join(text_words)
return [
copy_column_value(preprint, article, "prpnt_basic_title", "article_basic_title", replace_with_random_words)
for preprint, article in (custom_samples or self.positive_samples)
]
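
    # Example (a minimal sketch): replace two words of every title with random dictionary
    # words, so the article titles stay similar in shape but differ in content.
    #
    #     fuzzed = sampler.fuzz_title(fuzz_count=2)
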
def sample_authors_overlap_random(
self,
custom_samples: List[List[pd.DataFrame]] = None
) -> List[List[pd.DataFrame]]:
"""
Samples a random non-matching article and replaces its authors with the preprint authors to create the negative sample.
:param preprint: The preprint for which to create a negative sample.
:return: A negative sample with authors replaced.
"""
return [
copy_column_value(preprint, article, "prpnt_authors_id", "article_authors_id")
for preprint, article in (custom_samples or self.positive_samples)
]
def sample_authors_overlap(
self,
preprint_index: int,
factor_max: int = 4,
random_state: int = -1,
authors_to_consider: int = 1,
overlapping_authors: int = 1
) -> List[List[pd.DataFrame]]:
"""
Samples a published article with some author overlap to create the negative sample.
:param preprint: The preprint for which to create a negative sample.
:return: A negative sample with some author overlap.
"""
        def extract_authors(authors_str: str, authors_to_keep: int = -1) -> list:
            """
            Extracts a list of author IDs from a `$@$`-delimited string, with an optional
            limit on the number of authors to return.

            Args:
                authors_str (str): A string of author IDs separated by `$@$`.
                authors_to_keep (int, optional): The number of authors to keep. If -1, all authors are kept. Defaults to -1.

            Returns:
                list: A list of author IDs, truncated to `authors_to_keep` entries if a limit is given.

            Raises:
                ValueError: If no author IDs can be extracted from `authors_str`.
            """
            # Split the authors string into a list using the custom delimiter `$@$`
            authors_list = authors_str.split("$@$")
            if not authors_list:
                raise ValueError(f"Invalid input: {authors_str}. No author IDs could be extracted.")
            # Determine how many authors to keep (-1 means keep them all)
            if authors_to_keep == -1:
                authors_to_keep = len(authors_list)
            # Return the truncated list of authors
            return authors_list[:authors_to_keep]
suffix = "authors_id"
positive_preprint, _ = self.positive_samples[preprint_index]
preprint_authors = set(extract_authors(positive_preprint[f"prpnt_{suffix}"].iloc[0]))
        def confirm_overlap(article_authors: str) -> bool:
            # Keep the pair if the article's leading authors share enough IDs with the preprint's authors.
            article_authors = set(extract_authors(article_authors, authors_to_consider))
            return len(preprint_authors.intersection(article_authors)) >= overlapping_authors
# Collect preprint-article pairs where the article has some overlapping authors with the selected preprint.
# Exclude the pair matching the selected preprint to ensure proper functionality of random sampling later.
custom_samples = keep_on_condition(
self.positive_samples,
suffix,
preprint_index,
confirm_overlap
)
# If preprint_index == -1, no index is excluded from being sampled by sample_random.
# This is because the indices are derived from the following logic:
# np.array([j for j in range(len(positive_samples)) if j != preprint_index]).
# Since j >= 0 and preprint_index is -1, the condition (j != preprint_index) always evaluates to True.
return [
(positive_preprint, article)
for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
]
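
    # Example (a minimal sketch): draw up to three articles that share at least one author
    # with preprint 0, checking only each article's first two listed authors.
    #
    #     negatives = sampler.sample_authors_overlap(
    #         preprint_index=0, factor_max=3, random_state=42,
    #         authors_to_consider=2, overlapping_authors=1,
    #     )
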
def sample_similar_topic(
self,
preprint_index: int,
factor_max: int = 4,
random_state: int = -1
) -> List[List[pd.DataFrame]]:
"""
Samples a non-matching article with the same topic to create the negative sample.
:param preprint: The preprint for which to create a negative sample.
:param topic_key: The key in the dataset that contains the topics.
:return: A negative sample with a similar topic.
"""
suffix = "classification_primary_topic_field"
positive_preprint, positive_article = self.positive_samples[preprint_index]
# Collect preprint-article pairs where the article shares the same topic as the selected preprint.
# Exclude the pair matching the selected preprint to ensure proper functionality of random sampling later.
custom_samples = keep_on_condition(
self.positive_samples,
suffix,
preprint_index,
lambda x: x == positive_article[f"article_{suffix}"].iloc[0]
)
# If preprint_index == -1, no index is excluded from being sampled by sample_random.
# This is because the indices are derived from the following logic:
# np.array([j for j in range(len(positive_samples)) if j != preprint_index]).
# Since j >= 0 and preprint_index is -1, the condition (j != preprint_index) always evaluates to True.
return [
(positive_preprint, article)
for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
]
    def create_negative_samples(self, config):
        """
        Generate negative samples for every preprint according to the configuration and
        store them on `self.dataset.negative_samples`.

        The config is expected to expose: `overlap_auth`, `overlap_topic`, `random`,
        `fuzz_title`, `replace_auth`, `factor_max`, `seed`, `authors_to_consider`
        and `overlapping_authors`.
        """
negative_samples = []
for preprint_index in tqdm(range(len(self.positive_samples)), desc="Negative Sampling"):
negatives = []
if config.overlap_auth and not config.overlap_topic:
negatives = self.sample_authors_overlap(
preprint_index, factor_max=config.factor_max,
random_state=config.seed,
authors_to_consider=config.authors_to_consider,
overlapping_authors=config.overlapping_authors
)
elif config.overlap_topic and not config.overlap_auth:
negatives = self.sample_similar_topic(preprint_index, factor_max=config.factor_max, random_state=config.seed)
elif config.random:
negatives = self.sample_random(preprint_index, factor_max=config.factor_max, random_state=config.seed)
else:
continue
if config.fuzz_title:
negatives = self.fuzz_title(custom_samples=negatives)
if config.replace_auth:
negatives = self.sample_authors_overlap_random(negatives)
negative_samples.extend(negatives)
        self.dataset.negative_samples = negative_samples
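

if __name__ == "__main__":
    # Minimal usage sketch. `AugmentedDataset` comes from src.dataset.GoodDataset; how it is
    # constructed or loaded is project-specific, so the bare call below is a placeholder.
    # The config attributes mirror the ones read in `create_negative_samples`.
    from types import SimpleNamespace

    config = SimpleNamespace(
        random=True, overlap_auth=False, overlap_topic=False,
        fuzz_title=True, replace_auth=False,
        factor_max=4, seed=42,
        authors_to_consider=1, overlapping_authors=1,
    )

    dataset = AugmentedDataset()  # placeholder: build/load the dataset elsewhere
    sampler = NegativeSampler(dataset)
    sampler.create_negative_samples(config)
    print(f"Created {len(dataset.negative_samples)} negative samples")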