File size: 13,959 Bytes
b5cf002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import random
from typing import List, Dict, Any, Union, Set, Callable
import copy

import pandas as pd
import numpy as np

import nltk
from nltk.corpus import words
nltk.download("words")

from src.dataset.GoodDataset import *


def copy_column_value(
    df1: pd.DataFrame, 
    df2: pd.DataFrame, 
    source_col: str, 
    target_col: str,
    source_transform: Callable[[Any], Any] = lambda x: x
) -> List[pd.DataFrame]:
    """
    Copies the value from `source_col` in `df1` to `target_col` in `df2`,
    while ensuring that the original DataFrames remain unaltered by
    working on a deep copy of `df2`.

    Args:
        df1 (pd.DataFrame): The source DataFrame containing the value to copy.
        df2 (pd.DataFrame): The target DataFrame where the value will be copied.
        source_col (str): The column name in `df1` from which the value will be sourced.
        target_col (str): The column name in `df2` where the value will be written.
        source_transform (Callable[[Any], Any], optional): Transformation applied to the
            source value before writing it into the target. Defaults to the identity.

    Returns:
        List[pd.DataFrame]: A two-element list: the original `df1` (unmodified) and
        the modified deep copy of `df2`.
    """
    # Deep-copy `df2` so the caller's DataFrame (and any mutable objects inside
    # its cells) is never mutated.
    df2_copy = copy.deepcopy(df2)

    # Value from the first row of the source column (positional, row 0).
    value_to_copy = df1.iloc[0][source_col]

    # NOTE(review): `.at[0, target_col]` addresses the row *labelled* 0, which
    # assumes a default RangeIndex — TODO confirm all frames use one.
    df2_copy.at[0, target_col] = source_transform(value_to_copy)

    return [df1, df2_copy]

def keep_on_condition(
    dataset: List[List[pd.DataFrame]], 
    column_to_check: str, 
    indices_to_ignore: Union[List[int], Set[int], int],
    function_to_compare: Callable[[Any], bool]
) -> List[List[pd.DataFrame]]:
    """
    Filters a dataset of preprint-article pairs, keeping only pairs whose article
    satisfies a predicate on one column, while skipping specified indices.

    The article column actually inspected is `f"article_{column_to_check}"`; pairs
    whose article lacks that column are dropped.

    Args:
        dataset (List[List[pd.DataFrame]]): The dataset to filter, organized as a
            list of [preprint, article] DataFrame pairs.
        column_to_check (str): Column suffix to check on the article DataFrame
            (prefixed with "article_" internally).
        indices_to_ignore (Union[List[int], Set[int], int]): Position(s) in
            `dataset` to exclude from the result regardless of the predicate.
        function_to_compare (Callable[[Any], bool]): Predicate applied to the
            first-row value of the checked column; the pair is kept when it
            returns True.

    Returns:
        List[List[pd.DataFrame]]: The filtered dataset.
    """
    # Normalize `indices_to_ignore` to a set for O(1) membership tests.
    if isinstance(indices_to_ignore, int):
        indices_to_ignore = {indices_to_ignore}
    elif isinstance(indices_to_ignore, list):
        indices_to_ignore = set(indices_to_ignore)

    # Keep a pair only if the column exists, the predicate holds on its first
    # row, and the pair's index is not explicitly ignored.
    return [
        [preprint, article]
        for i, (preprint, article) in enumerate(dataset)
        if (
            f"article_{column_to_check}" in article.columns and
            function_to_compare(article[f"article_{column_to_check}"].iloc[0]) and
            i not in indices_to_ignore
        )
    ]

class NegativeSampler:
    """
    Generates negative (non-matching) preprint-article pairs from a dataset of
    known positive pairs, using several sampling strategies (random article,
    fuzzed title, author overlap, shared topic).
    """

    def __init__(self, dataset: AugmentedDataset):
        """
        Initializes the NegativeSampler with a dataset of preprint-article pairs.

        Args:
            dataset (AugmentedDataset): Dataset exposing `positive_samples`, a list
                of [preprint, article] DataFrame pairs. `create_negative_samples`
                writes its output back to `dataset.negative_samples`.
        """
        self.dataset = dataset
        self.positive_samples = dataset.positive_samples

    def sample_random(
        self, 
        preprint_index: int, 
        factor_max: int = 4,
        random_state: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Randomly samples non-matching articles to create negative samples.

        Args:
            preprint_index (int): Index of the preprint's positive pair. Pass a
                negative value to exclude nothing and return raw sampled pairs.
            factor_max (int): Maximum number of negatives to draw.
            random_state (int): Seed for numpy's RNG; ignored when negative.
            custom_samples: Optional pool of pairs to sample from instead of
                `self.positive_samples`.

        Returns:
            List[List[pd.DataFrame]]: Negative [preprint, article] pairs, or the
            raw sampled pairs when `preprint_index` is negative.
        """
        if random_state >= 0:
            np.random.seed(random_state)

        positive_samples = custom_samples if custom_samples is not None else self.positive_samples

        # Candidate indices: every pair except the preprint's own. When
        # preprint_index < 0 no index matches, so nothing is excluded.
        other_indices = np.array([j for j in range(len(positive_samples)) if j != preprint_index])

        # BUGFIX: cap the draw by the number of *candidates*, not the full pool
        # size. The original used len(positive_samples), which made
        # np.random.choice(replace=False) raise whenever factor_max >= pool size
        # and one index was excluded.
        factor = min(len(other_indices), factor_max)
        assert factor >= 1, "Dataset doesn't contain enough samples"

        sampled_indices = np.random.choice(other_indices, size=factor, replace=False)
        sampled_rows = [positive_samples[j] for j in sampled_indices]

        if preprint_index < 0:
            return sampled_rows

        # Pair the original preprint with each sampled non-matching article.
        preprint, _ = positive_samples[preprint_index]
        return [
            [preprint, non_matching_article]
            for _, non_matching_article in sampled_rows
        ]

    def fuzz_title(
        self, 
        fuzz_count: int = -1,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Creates negative samples by copying each preprint's title onto its
        article with some words replaced by random dictionary words.

        Args:
            fuzz_count (int): Number of words to replace per title; -1 replaces
                roughly half the words.
            custom_samples: Optional pairs to fuzz instead of the positives.

        Returns:
            List[List[pd.DataFrame]]: Pairs whose article titles are fuzzed
            copies of the preprint titles.
        """
        # Load the (large) NLTK corpus once per call instead of once per title.
        word_list = words.words()

        def replace_with_random_words(text: str, fuzz_count: int = fuzz_count) -> str:
            """Replace `fuzz_count` randomly chosen words in `text` with random
            words from the NLTK `words` corpus."""
            if fuzz_count == -1:
                # Default: replace about half of the words.
                fuzz_count = text.count(" ") // 2

            text_words = text.split()
            if not text_words:
                # Robustness: an empty title has nothing to fuzz (the original
                # raised from random.randint on an empty range).
                return text

            for _ in range(fuzz_count):
                index_to_replace = random.randint(0, len(text_words) - 1)
                text_words[index_to_replace] = random.choice(word_list)

            return " ".join(text_words)

        return [
            copy_column_value(preprint, article, "prpnt_basic_title", "article_basic_title", replace_with_random_words)
            for preprint, article in (custom_samples or self.positive_samples)
        ]

    def sample_authors_overlap_random(
        self,
        custom_samples: List[List[pd.DataFrame]] = None
    ) -> List[List[pd.DataFrame]]:
        """
        Creates negative samples by overwriting each article's author IDs with
        the paired preprint's author IDs.

        Args:
            custom_samples: Optional pairs to modify instead of the positives.

        Returns:
            List[List[pd.DataFrame]]: Pairs with article authors replaced.
        """
        return [
            copy_column_value(preprint, article, "prpnt_authors_id", "article_authors_id")
            for preprint, article in (custom_samples or self.positive_samples)
        ]

    def sample_authors_overlap(
        self,
        preprint_index: int, 
        factor_max: int = 4,
        random_state: int = -1,
        authors_to_consider: int = 1,
        overlapping_authors: int = 1
    ) -> List[List[pd.DataFrame]]:
        """
        Samples published articles sharing at least `overlapping_authors` authors
        with the selected preprint to create negative samples.

        Args:
            preprint_index (int): Index of the preprint's positive pair.
            factor_max (int): Maximum number of negatives to draw.
            random_state (int): Seed for numpy's RNG; ignored when negative.
            authors_to_consider (int): How many leading article authors to check.
            overlapping_authors (int): Minimum shared authors required.

        Returns:
            List[List[pd.DataFrame]]: Negative pairs with author overlap.
        """
        def extract_authors(authors_str: str, authors_to_keep: int = -1) -> list:
            """
            Split an author string on the `$@$` delimiter and return up to
            `authors_to_keep` authors (-1 keeps all).
            """
            authors_list = authors_str.split("$@$")

            # NOTE: str.split never yields an empty list, so the original
            # `if not authors_list: raise ValueError` was unreachable and has
            # been removed.
            if authors_to_keep == -1:
                authors_to_keep = len(authors_list)  # keep every author

            return authors_list[:authors_to_keep]

        suffix = "authors_id"
        positive_preprint, _ = self.positive_samples[preprint_index]
        preprint_authors = set(extract_authors(positive_preprint[f"prpnt_{suffix}"].iloc[0]))

        def confirm_overlap(article_authors: str) -> bool:
            """True when the first `authors_to_consider` article authors share at
            least `overlapping_authors` IDs with the preprint's authors."""
            # Debug print removed; intersection computed once instead of twice.
            article_author_set = set(extract_authors(article_authors, authors_to_consider))
            return len(preprint_authors & article_author_set) >= overlapping_authors

        # Collect pairs whose article overlaps in authorship with the preprint,
        # excluding the preprint's own positive pair.
        custom_samples = keep_on_condition(
            self.positive_samples, 
            suffix,
            preprint_index,
            confirm_overlap
        )

        # sample_random(-1, ...) excludes no index from `custom_samples`: its
        # candidate filter (j != preprint_index) is always True for j >= 0.
        return [
            (positive_preprint, article)
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def sample_similar_topic(
        self,
        preprint_index: int, 
        factor_max: int = 4,
        random_state: int = -1
    ) -> List[List[pd.DataFrame]]:
        """
        Samples non-matching articles sharing the selected preprint's primary
        topic field to create negative samples.

        Args:
            preprint_index (int): Index of the preprint's positive pair.
            factor_max (int): Maximum number of negatives to draw.
            random_state (int): Seed for numpy's RNG; ignored when negative.

        Returns:
            List[List[pd.DataFrame]]: Negative pairs on the same topic.
        """
        suffix = "classification_primary_topic_field"
        positive_preprint, positive_article = self.positive_samples[preprint_index]
        target_topic = positive_article[f"article_{suffix}"].iloc[0]

        # Collect pairs whose article shares the preprint's topic, excluding the
        # preprint's own positive pair.
        custom_samples = keep_on_condition(
            self.positive_samples, 
            suffix,
            preprint_index,
            lambda x: x == target_topic
        )

        # sample_random(-1, ...) excludes no index from `custom_samples`: its
        # candidate filter (j != preprint_index) is always True for j >= 0.
        return [
            (positive_preprint, article)
            for _, article in self.sample_random(-1, factor_max, random_state, custom_samples)
        ]

    def create_negative_samples(self, config):
        """
        Generate negative samples for every positive pair according to `config`
        and store them on `self.dataset.negative_samples`.

        Config flags used: `overlap_auth`, `overlap_topic`, `random`,
        `fuzz_title`, `replace_auth`, `factor_max`, `seed`,
        `authors_to_consider`, `overlapping_authors`.
        """
        # NOTE(review): `tqdm` is not imported explicitly here — presumably it
        # arrives via `from src.dataset.GoodDataset import *`; verify.
        negative_samples = []
        for preprint_index in tqdm(range(len(self.positive_samples)), desc="Negative Sampling"):
            negatives = []
            # Exactly one base strategy is applied per preprint; a pair that
            # sets both overlap flags (or none) is skipped entirely.
            if config.overlap_auth and not config.overlap_topic:
                negatives = self.sample_authors_overlap(
                    preprint_index, factor_max=config.factor_max,
                    random_state=config.seed,
                    authors_to_consider=config.authors_to_consider,
                    overlapping_authors=config.overlapping_authors
                )
            elif config.overlap_topic and not config.overlap_auth:
                negatives = self.sample_similar_topic(preprint_index, factor_max=config.factor_max, random_state=config.seed)
            elif config.random:
                negatives = self.sample_random(preprint_index, factor_max=config.factor_max, random_state=config.seed)
            else:
                continue

            # Optional post-processing of the base negatives.
            if config.fuzz_title:
                negatives = self.fuzz_title(custom_samples=negatives)

            if config.replace_auth:
                negatives = self.sample_authors_overlap_random(negatives)

            negative_samples.extend(negatives)

        self.dataset.negative_samples = negative_samples