from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from src.dataset.DataAugmenter import DataAugmenter, DatasetType, PROJECT_ROOT

class FullAugmentedDataset:
    """Build an augmented preprint/article matching dataset: raw Crossref
    preprint-article relationships enriched with OpenAlex features and
    expanded with negative samples."""

    def __init__(self):
        self.augmenter = DataAugmenter()
        self.full_raw_dataset = self._load_the_dataset()

    def _load_the_dataset(self, dataset_type: DatasetType = DatasetType.FULL_RAW) -> pd.DataFrame:
        """Load one of the training datasets from its CSV file."""
        assert str(PROJECT_ROOT).split("/")[-1] == "MatchingPubs", "Please run this script from the MatchingPubs repository root"

        if dataset_type == DatasetType.FULL_RAW:
            return pd.read_csv(f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv")
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    def retrieve_dois_couple(self, n: int = 1, random: bool = False, seed: Optional[int] = None, full: bool = False) -> np.ndarray:
        """Retrieve (preprint_doi, article_doi) pairs from the dataset.

        Returns every pair when `full` is True; otherwise `n` randomly sampled
        pairs when `random` is True, else the first `n` pairs.
        """
        if full:
            dois = self.full_raw_dataset[["preprint_doi", "article_doi"]]
        elif random:
            dois = self.full_raw_dataset.sample(n=n, random_state=seed)[["preprint_doi", "article_doi"]]
        else:
            dois = self.full_raw_dataset.head(n)[["preprint_doi", "article_doi"]]
        return dois.to_numpy()
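
    # Illustrative output of retrieve_dois_couple (DOIs are made up):
    #   ds.retrieve_dois_couple(n=2)
    #   -> array([['10.1101/2023.001', '10.1038/s001'],
    #             ['10.1101/2023.002', '10.1038/s002']])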
    
    @staticmethod
    def _flatten_list(lst):
        """
        Flattens a nested list into a single list. If the input is not nested, it returns the original list.
        Handles cases where some elements are lists and others are not.
        """
        if not isinstance(lst, list):  # Ensure the input is a list
            raise ValueError("Input must be a list")

        def _flatten(sublist):
            for item in sublist:
                if isinstance(item, list):  # Check if the item is a list
                    yield from _flatten(item)  # Recursively flatten the list
                else:
                    yield item  # Yield the non-list item

        return list(_flatten(lst))
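
    # Example (illustrative):
    #   FullAugmentedDataset._flatten_list([1, [2, [3, 4]], 5]) -> [1, 2, 3, 4, 5]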
    
    def _augmented_data_to_row(self, filtered_data: Dict[str, Any], preprint: bool = True) -> pd.DataFrame:
        """Transform filtered augmented data into a single-row DataFrame.

        Args:
            filtered_data: Dictionary containing filtered OpenAlex and Elsevier data
            preprint: If True, prefix columns with prpnt_, else with article_

        Returns:
            pd.DataFrame: Flattened data as a single row
        """
        additional_part = FullAugmentedDataset.filter_author(filtered_data.get("authors", {}))
        # Prefix every author-derived key with "authors_"
        additional_part = {f"authors_{k}": v for k, v in additional_part.items()}
        # Replace the nested "authors" entry with the flattened author columns
        filtered_data.pop("authors", None)
        filtered_data.update(additional_part)
        final_dictionary = FullAugmentedDataset.flatten_dict(filtered_data, preprint=preprint)

        # Join list values into a single "$@$"-separated string; wrap every value
        # in a one-element list so the DataFrame constructor builds exactly one row.
        for k, v in final_dictionary.items():
            final_dictionary[k] = ["$@$".join(map(str, FullAugmentedDataset._flatten_list(v)))] if isinstance(v, list) else [v]

        return pd.DataFrame(final_dictionary)
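
    # The result is a one-row DataFrame; list-valued fields arrive joined with
    # the "$@$" separator, e.g. a prpnt_authors_... column holding
    # "A. One$@$B. Two" (names here are made up).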

    @staticmethod
    def filter_author(authors_info: list) -> dict:
        """Pivot a list of per-author dicts into a dict of per-key lists."""
        if not authors_info:
            return {}
        try:
            relevant_keys = authors_info[0].keys()
            return {key: [author.get(key) for author in authors_info] for key in relevant_keys}
        except (AttributeError, TypeError):
            return {}
    
    @staticmethod
    def flatten_dict(d: dict, parent_key: str = '', sep: str = '_', preprint: bool = True) -> dict:
        """Flatten a nested dictionary.

        Args:
            d (dict): The dictionary to flatten.
            parent_key (str): The base key string to use for the flattened keys.
            sep (str): The separator to use between parent and child keys.
            preprint (bool): If True, prefix keys with "prpnt_", else "article_".

        Returns:
            dict: The flattened dictionary.
        """
        addition = "prpnt_" if preprint else "article_"

        def _flatten_dict(d: dict, parent_key: str = '', sep: str = '_') -> dict:
            items = []
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                if isinstance(v, dict):
                    items.extend(_flatten_dict(v, new_key, sep=sep).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        return {f"{addition}{k}": v for k, v in _flatten_dict(d, parent_key, sep).items()}
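
    # Example (illustrative):
    #   flatten_dict({"title": "T", "ids": {"doi": "10.1/x"}}, preprint=True)
    #   -> {"prpnt_title": "T", "prpnt_ids_doi": "10.1/x"}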

    def process_pair(self, dois) -> list:
        """Augment each (preprint_doi, article_doi) pair and return a list of
        [preprint_row, article_row] single-row DataFrame pairs."""
        assert len(dois) > 0
        rows = []
        for preprint_doi, article_doi in tqdm(dois):
            # Get preprint features
            preprint_features = self.augmenter.get_alex_features(preprint_doi)  # augment with all the features
            preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)  # keep only the relevant features
            preprint_row = self._augmented_data_to_row(preprint_filtered, preprint=True)

            # Get article features
            article_features = self.augmenter.get_alex_features(article_doi)
            article_filtered = self.augmenter.filter_augmented_data(article_features)
            article_row = self._augmented_data_to_row(article_filtered, preprint=False)

            rows.append([preprint_row, article_row])

        return rows

    @staticmethod
    def transform_array(input_array, factor):
        """Build labelled training rows by negative sampling.

        For each (preprint_row, article_row) pair, emit one positive row
        (label 1) and `factor` negative rows pairing the preprint with
        articles drawn from other pairs (label 0).
        """
        output_list = []

        for i, row in enumerate(input_array):
            other_indices = np.array([j for j in range(len(input_array)) if j != i])
            # Guard against `factor` exceeding the number of available negatives
            sampled_indices = np.random.choice(other_indices, size=min(factor, len(other_indices)), replace=False)
            sampled_rows = [input_array[j] for j in sampled_indices]

            output_list.append(pd.concat([row[0], row[1], pd.DataFrame(data=[1], columns=['label'])], axis=1))
            for B in sampled_rows:
                output_list.append(pd.concat([row[0], B[1], pd.DataFrame(data=[0], columns=['label'])], axis=1))

        return pd.concat(output_list).reset_index(drop=True)
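
    # Sizing example: with 3 input pairs and factor=2, each pair contributes
    # 1 positive and 2 negative rows, giving 9 labelled rows in total.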

    def get_full_dataset(self, n: int = 1, random: bool = True, seed: int = 42, full: bool = True) -> pd.DataFrame:
        """Process all DOI pairs and return the full labelled dataset."""
        dois = self.retrieve_dois_couple(n, random, seed, full)
        self.augmented_df = FullAugmentedDataset.transform_array(self.process_pair(dois), factor=4)
        return self.augmented_df
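
# Minimal usage sketch (illustrative; assumes the Crossref CSV is present under
# data/ and that DataAugmenter can reach the OpenAlex API):
if __name__ == "__main__":
    dataset = FullAugmentedDataset()
    # Sample 10 random DOI pairs and build a small labelled dataset.
    df = dataset.get_full_dataset(n=10, random=True, seed=42, full=False)
    print(df.shape)
    print(df.head())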