from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from src.dataset.DataAugmenter import DataAugmenter, DatasetType, PROJECT_ROOT


class FullAugmentedDataset:
    """Build a labeled preprint/article matching dataset from Crossref DOI pairs."""

    def __init__(self):
        self.augmenter = DataAugmenter()
        self.full_raw_dataset = self._load_the_dataset()

    def _load_the_dataset(self, dataset_type: DatasetType = DatasetType.FULL_RAW) -> pd.DataFrame:
        """Load one of the training datasets from its CSV file."""
        assert str(PROJECT_ROOT).split("/")[-1] == "MatchingPubs", "Please run this script from the MatchingPubs repository root"
        if dataset_type == DatasetType.FULL_RAW:
            return pd.read_csv(f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv")
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    def retrieve_dois_couple(self, n: int = 1, random: bool = False, seed: Optional[int] = None, full: bool = False) -> np.ndarray:
        """Retrieve (preprint_doi, article_doi) pairs from the dataset."""
        if full:
            dois = self.full_raw_dataset[["preprint_doi", "article_doi"]]
        elif random:
            dois = self.full_raw_dataset.sample(n=n, random_state=seed)[["preprint_doi", "article_doi"]]
        else:
            dois = self.full_raw_dataset.head(n)[["preprint_doi", "article_doi"]]
        return dois.to_numpy()

    @staticmethod
    def _flatten_list(lst):
        """
        Flatten a nested list into a single list. If the input is not nested,
        return the original items unchanged. Handles cases where some elements
        are lists and others are not.
        """
        if not isinstance(lst, list):  # Ensure the input is a list
            raise ValueError("Input must be a list")

        def _flatten(sublist):
            for item in sublist:
                if isinstance(item, list):  # Recursively flatten nested lists
                    yield from _flatten(item)
                else:
                    yield item  # Yield the non-list item as-is

        return list(_flatten(lst))
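
    # Illustrative examples only: _flatten_list collapses arbitrary nesting and
    # leaves already-flat lists unchanged, e.g.
    #   FullAugmentedDataset._flatten_list([[1, 2], 3, [[4], 5]]) -> [1, 2, 3, 4, 5]
    #   FullAugmentedDataset._flatten_list(["a", "b"])            -> ["a", "b"]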

    def _augmented_data_to_row(self, filtered_data: Dict[str, Any], preprint: bool = True) -> pd.DataFrame:
        """Transform filtered augmented data into a single-row DataFrame.

        Args:
            filtered_data: Dictionary containing filtered OpenAlex and Elsevier data.
            preprint: If True, use the prpnt_ prefix, else the article_ prefix.

        Returns:
            pd.DataFrame: Flattened data as a single row.
        """
        additional_part = FullAugmentedDataset.filter_author(filtered_data.get("authors", {}))
        # Prefix every author field with "authors_"
        additional_part = {f"authors_{k}": v for k, v in additional_part.items()}
        # Replace the nested "authors" entry with the flattened author columns
        filtered_data.pop("authors", None)
        filtered_data.update(additional_part)
        final_dictionary = FullAugmentedDataset.flatten_dict(filtered_data, preprint=preprint)
        for k, v in final_dictionary.items():
            # Join list values into a single "$@$"-separated string; wrap scalars so
            # every column holds a one-element list
            final_dictionary[k] = ["$@$".join(map(str, FullAugmentedDataset._flatten_list(v)))] if isinstance(v, list) else [v]
        return pd.DataFrame(final_dictionary)
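
    # Illustrative example only (hypothetical input): list-valued fields are
    # joined with the "$@$" separator so each record fits in one row, e.g.
    #   {"title": "T", "authors": [{"name": "A"}, {"name": "B"}]}
    #   -> columns prpnt_title = "T", prpnt_authors_name = "A$@$B"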

    @staticmethod
    def filter_author(authors_info: list) -> dict:
        """Convert a list of per-author dicts into a dict of per-field lists."""
        if not authors_info:
            return {}
        try:
            relevant_keys = authors_info[0].keys()
            return {key: [author[key] for author in authors_info] for key in relevant_keys}
        except (AttributeError, KeyError, TypeError):
            return {}

    @staticmethod
    def flatten_dict(d: dict, parent_key: str = '', sep: str = '_', preprint: bool = True) -> dict:
        """Flatten a nested dictionary.

        Args:
            d (dict): The dictionary to flatten.
            parent_key (str): The base key string to use for the flattened keys.
            sep (str): The separator to use between parent and child keys.
            preprint (bool): If True, prefix keys with "prpnt_", else "article_".

        Returns:
            dict: The flattened dictionary.
        """
        addition = "prpnt_" if preprint else "article_"

        def _flatten_dict(d: dict, parent_key: str = '', sep: str = '_') -> dict:
            items = []
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                if isinstance(v, dict):
                    items.extend(_flatten_dict(v, new_key, sep=sep).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        return {f"{addition}{k}": v for k, v in _flatten_dict(d, parent_key, sep).items()}

    def process_pair(self, dois) -> list:
        """Process (preprint, article) DOI pairs and return a list of [preprint_row, article_row] pairs of single-row DataFrames."""
        assert len(dois) > 0, "No DOI pairs to process"
        rows = []
        for preprint_doi, article_doi in tqdm(dois):
            # Augment the preprint with all available features, then keep the relevant ones
            preprint_features = self.augmenter.get_alex_features(preprint_doi)
            preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
            preprint_row = self._augmented_data_to_row(preprint_filtered, True)
            # Do the same for the published article
            article_features = self.augmenter.get_alex_features(article_doi)
            article_filtered = self.augmenter.filter_augmented_data(article_features)
            article_row = self._augmented_data_to_row(article_filtered, False)
            rows.append([preprint_row, article_row])
        return rows

    @staticmethod
    def transform_array(input_array, factor) -> pd.DataFrame:
        """Label the processed pairs: every true (preprint, article) pair gets label 1,
        plus `factor` negative examples (label 0) built by pairing the preprint with
        articles sampled from other rows."""
        output_list = []
        for i, row in enumerate(input_array):
            other_indices = np.array([j for j in range(len(input_array)) if j != i])
            # Guard against asking for more negatives than there are other rows
            sampled_indices = np.random.choice(other_indices, size=min(factor, len(other_indices)), replace=False)
            sampled_rows = [input_array[j] for j in sampled_indices]
            # Positive example: the true preprint/article pair
            output_list.append(pd.concat([row[0], row[1], pd.DataFrame(data=[1], columns=['label'])], axis=1))
            # Negative examples: the preprint paired with mismatched articles
            for other in sampled_rows:
                output_list.append(pd.concat([row[0], other[1], pd.DataFrame(data=[0], columns=['label'])], axis=1))
        return pd.concat(output_list).reset_index(drop=True)
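
    # Illustrative example only: with 10 input pairs and factor=4, transform_array
    # yields 10 positive rows and 40 negative rows, i.e. 50 labeled rows in which
    # each preprint also appears alongside 4 mismatched articles.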

    def get_full_dataset(self, n: int = 1, random: bool = True, seed: int = 42, full: bool = True) -> pd.DataFrame:
        """Process all DOI pairs and return the full labeled dataset."""
        dois = self.retrieve_dois_couple(n, random, seed, full)
        self.augmented_df = FullAugmentedDataset.transform_array(self.process_pair(dois), factor=4)
        return self.augmented_df
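

# A minimal usage sketch, assuming a working DataAugmenter and the Crossref CSV
# under data/ (run from the MatchingPubs repository root). The argument values
# below are illustrative, not part of a fixed API.
if __name__ == "__main__":
    dataset = FullAugmentedDataset()
    # Build labeled rows for 5 random DOI pairs: 1 positive + 4 negatives each
    df = dataset.get_full_dataset(n=5, random=True, seed=42, full=False)
    print(df.shape)
    print(df["label"].value_counts())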