# Source: Hugging Face Space "app", commit b5cf002 (uploaded by tmencatt).
from src.dataset.DataAugmenter import *
import pandas as pd
from tqdm import tqdm
import numpy as np
class FullAugmentedDataset:
    """Builds an augmented, labeled preprint/article matching dataset.

    Wraps a DataAugmenter plus the raw crossref preprint/article relationship
    CSV, and exposes helpers that turn DOI pairs into flat feature rows with
    positive/negative labels.
    """

    def __init__(self):
        # Raw crossref preprint<->article relationships, loaded once up front.
        self.full_raw_dataset = self._load_the_dataset()
        # Fetches per-DOI features (OpenAlex) used to augment each pair.
        self.augmenter = DataAugmenter()
def _load_the_dataset(self, type: DatasetType = DatasetType.FULL_RAW) -> pd.DataFrame:
    """Load one of the training datasets from its CSV file.

    Args:
        type: which dataset flavour to load (only FULL_RAW is supported;
            parameter name kept for backward compatibility although it
            shadows the builtin).

    Returns:
        pd.DataFrame with the raw preprint/article relationship data.

    Raises:
        RuntimeError: when not run from the MatchingPubs repository root.
        ValueError: for dataset types that have no loader yet.
    """
    # `assert` is stripped under `python -O`; raise explicitly instead.
    if str(PROJECT_ROOT).split("/")[-1] != "MatchingPubs":
        raise RuntimeError("Please run this script in the github repo folder ")
    if type == DatasetType.FULL_RAW:
        return pd.read_csv(f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv")
    # Previously this fell through and silently returned None.
    raise ValueError(f"Unsupported dataset type: {type}")
def retrieve_dois_couple(self, len: int = 1, random: bool = False, seed: int = None, full: bool = False):
    """Retrieve preprint/article DOI pairs from the raw dataset.

    Args:
        len: number of pairs to return (name kept for backward compatibility
            even though it shadows the builtin).
        random: sample rows at random instead of taking the head.
        seed: random_state forwarded to DataFrame.sample when random=True.
        full: return every pair, ignoring `len` and `random`.

    Returns:
        numpy array of shape (k, 2) with columns (preprint_doi, article_doi).
    """
    cols = ["preprint_doi", "article_doi"]
    # Check `full` first: the original sampled/headed rows and then threw
    # the work away whenever full=True.
    if full:
        dois = self.full_raw_dataset[cols]
    elif random:
        dois = self.full_raw_dataset.sample(n=len, random_state=seed)[cols]
    else:
        dois = self.full_raw_dataset.head(len)[cols]
    return dois.to_numpy()
@staticmethod
def _flatten_list(lst):
"""
Flattens a nested list into a single list. If the input is not nested, it returns the original list.
Handles cases where some elements are lists and others are not.
"""
if not isinstance(lst, list): # Ensure the input is a list
raise ValueError("Input must be a list")
def _flatten(sublist):
for item in sublist:
if isinstance(item, list): # Check if the item is a list
yield from _flatten(item) # Recursively flatten the list
else:
yield item # Yield the non-list item
return list(_flatten(lst))
def _augmented_data_to_row(self, filtered_data: Dict[str, Any], preprint: bool = True) -> pd.DataFrame:
    """Transform filtered augmented data into a single-row DataFrame.

    Args:
        filtered_data: dictionary of filtered augmented features; may contain
            an "authors" key holding a list of per-author dicts. Mutated in
            place ("authors" is removed and replaced by authors_* entries).
        preprint: if True prefix columns with "prpnt_", else "article_".

    Returns:
        pd.DataFrame with exactly one row (the annotation previously said
        pd.Series, but a DataFrame is what was always returned).
    """
    author_columns = FullAugmentedDataset.filter_author(filtered_data.get("authors", {}))
    # Namespace the author-derived columns under an authors_ prefix.
    author_columns = {f"authors_{k}": v for k, v in author_columns.items()}
    # pop with a default: "authors" may legitimately be absent, and .get()
    # above already tolerates that — a bare pop would raise KeyError.
    filtered_data.pop("authors", None)
    # Merge the author columns back into the feature dict.
    filtered_data.update(author_columns)
    final_dictionary = FullAugmentedDataset.flatten_dict(filtered_data, preprint=preprint)
    for k, v in final_dictionary.items():
        # Lists collapse to one "$@$"-joined string; wrap EVERY value in a
        # one-element list so pd.DataFrame also works when all values are
        # scalars (all-scalar dicts otherwise require an explicit index).
        final_dictionary[k] = ["$@$".join(map(str, FullAugmentedDataset._flatten_list(v)))] if isinstance(v, list) else [v]
    return pd.DataFrame(final_dictionary)
@staticmethod
def filter_author(authors_info : list) -> dict:
try:
relevant_keys = authors_info[0].keys()
new_dict = {}
for key in relevant_keys:
new_dict[key] = [author[key] for author in authors_info]
return new_dict
except:
return {}
@staticmethod
def flatten_dict(d: dict, parent_key: str = '', sep: str = '_', preprint = True) -> dict:
"""Flatten a nested dictionary.
Args:
d (dict): The dictionary to flatten.
parent_key (str): The base key string to use for the flattened keys.
sep (str): The separator to use between parent and child keys.
Returns:
dict: The flattened dictionary.
"""
addition = "prpnt_" if preprint else "article_"
def _flatten_dict(d: dict, parent_key: str = '', sep: str = '_') -> dict:
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(_flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
return {f"{addition}{k}": v for k, v in _flatten_dict(d, parent_key, sep).items()}
def process_pair(self, dois) -> list:
    """Augment each (preprint, article) DOI pair with OpenAlex features.

    Args:
        dois: iterable of (preprint_doi, article_doi) pairs, e.g. the array
            returned by retrieve_dois_couple.

    Returns:
        list of [preprint_row, article_row] pairs, each a one-row DataFrame
        (the previous -> pd.DataFrame annotation was wrong: a list was
        always returned).

    Raises:
        ValueError: if `dois` is empty (was an `assert`, which is stripped
            under `python -O`).
    """
    if len(dois) == 0:
        raise ValueError("process_pair requires at least one DOI pair")
    rows = []
    for preprint_doi, article_doi in tqdm(dois):
        # Preprint side: augment with all features, then keep the relevant ones.
        preprint_features = self.augmenter.get_alex_features(preprint_doi)
        preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
        preprint_row = self._augmented_data_to_row(preprint_filtered, True)
        # Article side: same pipeline with the article_ column prefix.
        article_features = self.augmenter.get_alex_features(article_doi)
        article_filtered = self.augmenter.filter_augmented_data(article_features)
        article_row = self._augmented_data_to_row(article_filtered, False)
        rows.append([preprint_row, article_row])
    return rows
@staticmethod
def transform_array(input_array, factor):
output_list = []
for i, row in enumerate(input_array):
other_indices = np.array([j for j in range(len(input_array)) if j != i])
sampled_indices = np.random.choice(other_indices, size=factor, replace=False)
sampled_rows = [input_array[j] for j in sampled_indices]
output_list.append(pd.concat([row[0], row[1], pd.DataFrame(data=[1], columns=['label'])], axis=1))
for B in sampled_rows:
output_list.append(pd.concat([row[0], B[1], pd.DataFrame(data=[0], columns=['label'])], axis=1))
return pd.concat(output_list).reset_index(drop=True)
def get_full_dataset(self, len: int = 1, random: bool = True, seed: int = 42, full: bool = True) -> pd.DataFrame:
    """Run the whole augmentation pipeline and return the labeled dataset.

    Retrieves DOI pairs, augments each side with features, and expands them
    into positive/negative examples (4 negatives per positive). The result
    is also cached on `self.augmented_df`.

    Args:
        len: number of pairs when not using the full dataset (name kept for
            backward compatibility even though it shadows the builtin).
        random: sample pairs at random instead of taking the head.
        seed: random_state for the sampling step.
        full: process every pair in the raw dataset.
    """
    doi_pairs = self.retrieve_dois_couple(len, random, seed, full)
    augmented_rows = self.process_pair(doi_pairs)
    self.augmented_df = FullAugmentedDataset.transform_array(augmented_rows, factor=4)
    return self.augmented_df