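"""Builds an augmented dataset of matched preprint/article DOI pairs.

Loads the Crossref preprint-article relationship dump, enriches each DOI with
OpenAlex features via DataAugmenter, and flattens the results into paired
one-row DataFrames that can be pickled or exported to CSV.
"""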
from src.dataset.GoodDataAugmenter import *
from src.utils.struct_utils import *
from typing import Any, Dict, List

import pandas as pd
from tqdm import tqdm
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import pickle as pkl

class AugmentedDataset:
    def __init__(self, path: str = None):
        """
        Initializes the AugmentedDataset object.

        Loads the dataset and prepares the augmenter for data augmentation tasks.
        """
        self.augmenter = DataAugmenter()
        self.full_raw_dataset = self._load_the_dataset(path)
        self.positive_samples = None

    def _load_the_dataset(self, path: str = None) -> pd.DataFrame:
        """
        Load the dataset from a CSV file.

        Args:
            path (str): Path to the CSV file. If None, the default Crossref
                preprint-article relationship dump is used.

        Returns:
            pd.DataFrame: The loaded dataset as a pandas DataFrame.
        """
        assert str(PROJECT_ROOT).split("/")[-1] == "MatchingPubs", \
            "Please run this script in the project repository folder."
        if not path:
            return pd.read_csv(f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv")
        return pd.read_csv(path)

    def sample_dois_pairs(
        self,
        num_samples: int = 1,
        random: bool = False,
        seed: int = None,
        full: bool = False,
    ) -> np.ndarray:
        """
        Sample DOI pairs from the dataset.

        Args:
            num_samples (int): Number of DOI pairs to sample.
            random (bool): If True, sample randomly; otherwise, use the top rows.
            seed (int): Random seed for reproducibility (used if random=True).
                None or negative values disable seeding.
            full (bool): If True, return all DOI pairs without sampling.

        Returns:
            np.ndarray: The sampled DOI pairs, one (preprint_doi, article_doi) pair per row.
        """
        # Treat a missing or negative seed as "no fixed seed".
        seed = seed if seed is not None and seed >= 0 else None
        num_samples = min(num_samples, len(self.full_raw_dataset))
        if full:
            sampled_data = self.full_raw_dataset[["preprint_doi", "article_doi"]]
        elif random:
            sampled_data = self.full_raw_dataset.sample(n=num_samples, random_state=seed)[["preprint_doi", "article_doi"]]
        else:
            sampled_data = self.full_raw_dataset.iloc[:num_samples][["preprint_doi", "article_doi"]]
        return sampled_data.to_numpy()

    def _augmented_data_to_row(self, filtered_data: Dict[str, Any], preprint: bool = True) -> pd.DataFrame:
        """
        Transform filtered augmented data into a single-row pandas DataFrame.

        Args:
            filtered_data: Dictionary containing filtered OpenAlex and Elsevier data.
            preprint: If True, use the "prpnt_" prefix; otherwise use "article_".

        Returns:
            pd.DataFrame: Flattened data as a single row.
        """
        # Promote the first author's fields to top-level "authors_*" keys.
        authors_info = filtered_data.pop("authors", {})
        if authors_info:
            additional_part = {f"authors_{k}": v for k, v in authors_info[0].items()}
            filtered_data.update(additional_part)
        prefix = "prpnt_" if preprint else "article_"
        final_dictionary = {f"{prefix}{k}": v for k, v in flatten_dict(filtered_data).items()}
        # Join list values into a single "$@$"-delimited string; wrap every value
        # in a one-element list so the dict builds a single DataFrame row.
        for key, value in final_dictionary.items():
            final_dictionary[key] = ["$@$".join(map(str, flatten_list(value)))] if isinstance(value, list) else [value]
        return pd.DataFrame(final_dictionary)

    def process_pairs(self, dois: np.ndarray) -> List[List[pd.DataFrame]]:
        """
        Process pairs of DOIs and return combined rows as a list of DataFrame pairs.

        Args:
            dois (np.ndarray): Array of DOI pairs.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        assert len(dois) > 0, "DOI pairs cannot be empty."
        rows = []
        for preprint_doi, article_doi in tqdm(dois, desc="Processing DOI pairs"):
            preprint_features = self.augmenter.get_alex_features(preprint_doi)
            article_features = self.augmenter.get_alex_features(article_doi)
            preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
            article_filtered = self.augmenter.filter_augmented_data(article_features)
            preprint_row = self._augmented_data_to_row(preprint_filtered, True)
            article_row = self._augmented_data_to_row(article_filtered, False)
            rows.append([preprint_row, article_row])
        return rows

    def fetch_positive_samples(
        self,
        num_samples: int = 1,
        random: bool = True,
        seed: int = 42,
        full: bool = True,
    ):
        """
        Process all DOI pairs and return the full augmented dataset.

        Args:
            num_samples (int): Number of DOI pairs to process.
            random (bool): Whether to sample DOI pairs randomly.
            seed (int): Seed for reproducibility.
            full (bool): If True, process the entire dataset.

        Returns:
            List[List[pd.DataFrame]]: The processed positive samples.
        """
        dois = self.sample_dois_pairs(num_samples, random, seed, full)
        self.positive_samples = self.process_pairs(dois)
        return self.positive_samples

    def process_pairs_parallel(self, dois: np.ndarray, max_workers: int = 4) -> List[List[pd.DataFrame]]:
        """
        Process pairs of DOIs in parallel and return combined rows as a list of DataFrame pairs.

        Args:
            dois (np.ndarray): Array of DOI pairs.
            max_workers (int): Number of threads to use for parallel processing.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        assert len(dois) > 0, "DOI pairs cannot be empty."

        def process_single_pair(preprint_doi: str, article_doi: str) -> List[pd.DataFrame]:
            """
            Process a single DOI pair to extract preprint and article data.

            Args:
                preprint_doi (str): DOI for the preprint.
                article_doi (str): DOI for the article.

            Returns:
                List[pd.DataFrame]: A list containing preprint and article rows,
                or an empty list if processing failed.
            """
            try:
                # Preprint features
                preprint_features = self.augmenter.get_alex_features(preprint_doi)
                preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
                preprint_row = self._augmented_data_to_row(preprint_filtered, True)
                # Article features
                article_features = self.augmenter.get_alex_features(article_doi)
                article_filtered = self.augmenter.filter_augmented_data(article_features)
                article_row = self._augmented_data_to_row(article_filtered, False)
                return [preprint_row, article_row]
            except Exception as e:
                print(f"Error processing pair ({preprint_doi}, {article_doi}): {e}")
                return []

        rows = []
        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit one task per DOI pair
            futures = {
                executor.submit(process_single_pair, preprint_doi, article_doi): (preprint_doi, article_doi)
                for preprint_doi, article_doi in dois
            }
            # Collect results as they complete
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing DOI pairs in parallel"):
                try:
                    result = future.result()
                    if result:  # Append only non-empty results
                        rows.append(result)
                except Exception as e:
                    doi_pair = futures[future]
                    print(f"Error with DOI pair {doi_pair}: {e}")
        return rows

    def fetch_positive_samples_parallel(
        self,
        num_samples: int = 1,
        random: bool = True,
        seed: int = 42,
        full: bool = True,
        max_workers: int = 4,
    ):
        """
        Process all DOI pairs in parallel and return the full augmented dataset.

        Args:
            num_samples (int): Number of DOI pairs to process.
            random (bool): Whether to sample DOI pairs randomly.
            seed (int): Seed for reproducibility.
            full (bool): If True, process the entire dataset.
            max_workers (int): Number of threads to use for parallel processing.

        Returns:
            List[List[pd.DataFrame]]: The processed positive samples.
        """
        dois = self.sample_dois_pairs(num_samples, random, seed, full)
        self.positive_samples = self.process_pairs_parallel(dois, max_workers=max_workers)
        return self.positive_samples

    def augment_dataset(
        self,
        augmentation_factor: int = 4,
        # possible augmentation parameters
    ):
        self.augmented_df = self.transform_array(self.positive_samples, factor=augmentation_factor)

    def save(self, path: str):
        """Pickle the positive samples to disk."""
        with open(path, "wb") as file:
            pkl.dump(self.positive_samples, file)

    def load(self, path: str):
        """Load pickled positive samples from disk."""
        with open(path, "rb") as file:
            self.positive_samples = pkl.load(file)

    def save_csv(self, path: str):
        """Export the positive samples to CSV via the shared struct helpers."""
        custom_struct_to_df(self.positive_samples).to_csv(path)

    def load_csv(self, path: str):
        """Restore positive samples from a CSV produced by save_csv."""
        self.positive_samples = df_to_custom_struct(pd.read_csv(path))
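

# A minimal usage sketch (not part of the original module): assumes the default
# Crossref CSV exists under data/ and that DataAugmenter can reach OpenAlex.
if __name__ == "__main__":
    dataset = AugmentedDataset()
    # Fetch a small random sample of augmented preprint/article pairs in parallel.
    samples = dataset.fetch_positive_samples_parallel(
        num_samples=8, random=True, seed=42, full=False, max_workers=4
    )
    print(f"Fetched {len(samples)} augmented DOI pairs")
    dataset.save("positive_samples.pkl")  # hypothetical output path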