from src.dataset.GoodDataAugmenter import *
from src.utils.struct_utils import *

import pandas as pd
import numpy as np
import pickle as pkl
from tqdm import tqdm
from typing import Any, Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

class AugmentedDataset:
    def __init__(self, path: Optional[str] = None):
        """
        Initializes the AugmentedDataset object.

        Loads the dataset and prepares the augmenter for data augmentation tasks.

        Args:
            path (str, optional): Path to a CSV file to load. If omitted, the
                default Crossref preprint-article relationships file is used.
        """
        self.augmenter = DataAugmenter()
        self.full_raw_dataset = self._load_the_dataset(path)
        self.positive_samples = None
    def _load_the_dataset(self, path: Optional[str] = None) -> pd.DataFrame:
        """
        Load the dataset from a CSV file.

        Args:
            path (str, optional): Path to the CSV file to load. If omitted, the
                default Crossref preprint-article relationships file is used.

        Returns:
            pd.DataFrame: The loaded dataset as a pandas DataFrame.
        """
        assert str(PROJECT_ROOT).split("/")[-1] == "MatchingPubs", \
            "Please run this script in the project repository folder."

        if not path:
            return pd.read_csv(f"{PROJECT_ROOT}/data/crossref-preprint-article-relationships-Aug-2023.csv")
        return pd.read_csv(path)
    def sample_dois_pairs(
        self,
        num_samples: int = 1,
        random: bool = False,
        seed: Optional[int] = None,
        full: bool = False
    ) -> np.ndarray:
        """
        Sample DOI pairs from the dataset.

        Args:
            num_samples (int): Number of DOI pairs to sample.
            random (bool): If True, sample randomly; otherwise, use the top rows.
            seed (int, optional): Random seed for reproducibility (used if random=True).
            full (bool): If True, return all DOI pairs without sampling.

        Returns:
            np.ndarray: The sampled DOI pairs.
        """
        # A negative or missing seed means "no fixed seed".
        seed = seed if seed is not None and seed >= 0 else None
        num_samples = min(num_samples, len(self.full_raw_dataset))

        if full:
            sampled_data = self.full_raw_dataset[["preprint_doi", "article_doi"]]
        elif random:
            sampled_data = self.full_raw_dataset.sample(n=num_samples, random_state=seed)[["preprint_doi", "article_doi"]]
        else:
            sampled_data = self.full_raw_dataset.iloc[:num_samples][["preprint_doi", "article_doi"]]
        return sampled_data.to_numpy()
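
    # E.g. sample_dois_pairs(num_samples=2) returns an array of shape (2, 2)
    # (hypothetical DOIs, for illustration only):
    #     [["10.1101/2023.01.01.000001", "10.1000/journal.000001"],
    #      ["10.1101/2023.01.02.000002", "10.1000/journal.000002"]]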
    def _augmented_data_to_row(self, filtered_data: Dict[str, Any], preprint: bool = True) -> pd.DataFrame:
        """
        Transform filtered augmented data into a single-row DataFrame.

        Args:
            filtered_data: Dictionary containing filtered OpenAlex and Elsevier data.
            preprint: If True, prefix columns with "prpnt_"; otherwise with "article_".

        Returns:
            pd.DataFrame: The flattened data as a single row.
        """
        # Promote the first author's fields to top-level "authors_*" keys.
        authors_info = filtered_data.pop("authors", {})
        if authors_info:
            additional_part = {f"authors_{k}": v for k, v in authors_info[0].items()}
            filtered_data.update(additional_part)

        prefix = "prpnt_" if preprint else "article_"
        final_dictionary = {f"{prefix}{k}": v for k, v in flatten_dict(filtered_data).items()}

        # Join list values into a single "$@$"-delimited string; wrap scalars in a
        # one-element list so the DataFrame constructor builds a single row.
        for key, value in final_dictionary.items():
            final_dictionary[key] = "$@$".join(map(str, flatten_list(value))) if isinstance(value, list) else [value]
        return pd.DataFrame(final_dictionary)
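
    # Illustrative sketch of the flattening above (hypothetical field names):
    # a filtered dict such as
    #     {"title": "Foo", "authors": [{"name": "A. B."}], "concepts": ["x", "y"]}
    # with preprint=True becomes a one-row DataFrame with columns
    #     prpnt_title == "Foo", prpnt_authors_name == "A. B.", prpnt_concepts == "x$@$y"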
    def process_pairs(self, dois: np.ndarray) -> List[List[pd.DataFrame]]:
        """
        Process pairs of DOIs and return combined rows as a list of DataFrame pairs.

        Args:
            dois (np.ndarray): Array of DOI pairs.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        assert len(dois) > 0, "DOI pairs cannot be empty."

        rows = []
        for preprint_doi, article_doi in tqdm(dois, desc="Processing DOI pairs"):
            preprint_features = self.augmenter.get_alex_features(preprint_doi)
            article_features = self.augmenter.get_alex_features(article_doi)

            preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
            article_filtered = self.augmenter.filter_augmented_data(article_features)

            preprint_row = self._augmented_data_to_row(preprint_filtered, True)
            article_row = self._augmented_data_to_row(article_filtered, False)

            rows.append([preprint_row, article_row])
        return rows
    def fetch_positive_samples(
        self,
        num_samples: int = 1,
        random: bool = True,
        seed: int = 42,
        full: bool = True,
    ):
        """
        Process all sampled DOI pairs and return the full augmented dataset.

        Args:
            num_samples (int): Number of DOI pairs to process.
            random (bool): Whether to sample DOI pairs randomly.
            seed (int): Seed for reproducibility.
            full (bool): If True, process the entire dataset.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        dois = self.sample_dois_pairs(num_samples, random, seed, full)
        self.positive_samples = self.process_pairs(dois)
        return self.positive_samples
    def process_pairs_parallel(self, dois: np.ndarray, max_workers: int = 4) -> List[List[pd.DataFrame]]:
        """
        Process pairs of DOIs in parallel and return combined rows as a list of DataFrame pairs.

        Args:
            dois (np.ndarray): Array of DOI pairs.
            max_workers (int): Number of threads to use for parallel processing.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        assert len(dois) > 0, "DOI pairs cannot be empty."

        def process_single_pair(preprint_doi: str, article_doi: str) -> List[pd.DataFrame]:
            """
            Process a single DOI pair to extract preprint and article data.

            Args:
                preprint_doi (str): DOI for the preprint.
                article_doi (str): DOI for the article.

            Returns:
                List[pd.DataFrame]: A list containing preprint and article rows,
                or an empty list if processing fails.
            """
            try:
                # Preprint features
                preprint_features = self.augmenter.get_alex_features(preprint_doi)
                preprint_filtered = self.augmenter.filter_augmented_data(preprint_features)
                preprint_row = self._augmented_data_to_row(preprint_filtered, True)

                # Article features
                article_features = self.augmenter.get_alex_features(article_doi)
                article_filtered = self.augmenter.filter_augmented_data(article_features)
                article_row = self._augmented_data_to_row(article_filtered, False)

                return [preprint_row, article_row]
            except Exception as e:
                print(f"Error processing pair ({preprint_doi}, {article_doi}): {e}")
                return []

        rows = []
        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit tasks to the executor
            futures = {
                executor.submit(process_single_pair, preprint_doi, article_doi): (preprint_doi, article_doi)
                for preprint_doi, article_doi in dois
            }
            # Collect results as they complete
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing DOI pairs in parallel"):
                try:
                    result = future.result()
                    if result:  # Append only non-empty results
                        rows.append(result)
                except Exception as e:
                    doi_pair = futures[future]
                    print(f"Error with DOI pair {doi_pair}: {e}")
        return rows
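
    # Note: threads (rather than processes) suit this workload because
    # get_alex_features is presumably dominated by network I/O (API lookups),
    # so the GIL is released while each worker waits on responses.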
    def fetch_positive_samples_parallel(
        self,
        num_samples: int = 1,
        random: bool = True,
        seed: int = 42,
        full: bool = True,
        max_workers: int = 4,
    ):
        """
        Process all sampled DOI pairs in parallel and return the full augmented dataset.

        Args:
            num_samples (int): Number of DOI pairs to process.
            random (bool): Whether to sample DOI pairs randomly.
            seed (int): Seed for reproducibility.
            full (bool): If True, process the entire dataset.
            max_workers (int): Number of threads to use for parallel processing.

        Returns:
            List[List[pd.DataFrame]]: List of preprint and article DataFrame pairs.
        """
        dois = self.sample_dois_pairs(num_samples, random, seed, full)
        self.positive_samples = self.process_pairs_parallel(dois, max_workers=max_workers)
        return self.positive_samples
    def augment_dataset(
        self,
        augmentation_factor: int = 4,
        # possible augmentation parameters
    ):
        # Requires positive samples to have been fetched or loaded first.
        self.augmented_df = self.transform_array(self.positive_samples, factor=augmentation_factor)
    def save(self, path: str):
        with open(path, 'wb') as file:
            pkl.dump(self.positive_samples, file)

    def load(self, path: str):
        with open(path, 'rb') as file:
            self.positive_samples = pkl.load(file)

    def save_csv(self, path: str):
        # index=False keeps the CSV round-trippable via load_csv below.
        custom_struct_to_df(self.positive_samples).to_csv(path, index=False)

    def load_csv(self, path: str):
        self.positive_samples = df_to_custom_struct(pd.read_csv(path))
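

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical sample size and output path; assumes
    # the default Crossref CSV is present under <PROJECT_ROOT>/data/).
    dataset = AugmentedDataset()

    # Fetch a small random subset of positive (preprint, article) pairs.
    dataset.fetch_positive_samples_parallel(
        num_samples=8, random=True, seed=42, full=False, max_workers=4
    )

    # Persist the fetched samples for later reuse.
    dataset.save("positive_samples.pkl")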