# MatchPrePrintArticles / src/dataset/GoodDataAugmenter.py
# Committed by tmencatt — commit b5cf002 ("app")
import os
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import pyalex
import yaml
from pyalex import Works
from requests.exceptions import RequestException
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential, wait_fixed

from src.utils.io_utils import PROJECT_ROOT
@dataclass
class ConfigAugmentation:
    """Configuration for OpenAlex features.

    Each section maps a feature name to a flag saying whether that feature is
    kept. A section left as ``None`` (or empty) is skipped entirely by the
    filtering step.
    """
    basic: Optional[Dict[str, bool]] = None  # id, doi, title, etc
    source: Optional[Dict[str, bool]] = None  # journal info
    authors: Optional[Dict[str, bool]] = None  # author details
    metrics: Optional[Dict[str, bool]] = None  # citations, fwci, etc
    classification: Optional[Dict[str, bool]] = None  # topics, concepts
    access: Optional[Dict[str, bool]] = None  # OA status
    related_works: Optional[Dict[str, bool]] = None  # references
    abstract: bool = False  # whether to keep the reconstructed abstract text
class DatasetType(Enum):
    """Dataset variants; each member's value is its lowercase string tag."""

    # Raw variants
    FULL_RAW = "full_raw"
    PARTIAL_RAW = "partial_raw"
    # Augmented variants
    FULL_AUGMENTED = "full_augmented"
    PARTIAL_AUGMENTED = "partial_augmented"
@dataclass
class Field:
    """Describes one value to pull out of a nested record.

    The value is found by following ``path`` key by key from the record root;
    ``default`` is the fallback when the path cannot be followed.
    """

    # Output key under which the extracted value is stored.
    name: str
    # Sequence of keys leading from the record root to the value.
    path: List[str]
    # Fallback value when the path does not resolve.
    default: Any = None
class AlexFields:
    """OpenAlex field definitions, grouped by output section.

    Each list holds ``Field`` entries; a field's path is the chain of keys
    leading to the value inside an OpenAlex work record.
    """

    # Top-level scalar metadata: the output name equals the record key.
    BASIC = [
        Field(name, [name])
        for name in (
            "id",
            "doi",
            "title",
            "display_name",
            "publication_year",
            "publication_date",
            "language",
            "type",
            "type_crossref",
        )
    ]

    # Journal / venue data, all nested under primary_location.source.
    SOURCE = [
        Field(name, ["primary_location", "source", key])
        for name, key in (
            ("journal_name", "display_name"),
            ("issn", "issn"),
            ("issn_l", "issn_l"),
            ("publisher", "host_organization_name"),
            ("type", "type"),
        )
    ]

    # Citation and impact figures stored at the top level of the record.
    METRICS = [
        Field(name, [key])
        for name, key in (
            ("cited_by_count", "cited_by_count"),
            ("cited_by_percentile", "citation_normalized_percentile"),
            ("is_retracted", "is_retracted"),
            ("fwci", "fwci"),
            ("referenced_works_count", "referenced_works_count"),
        )
    ]

    # Open-access information: (parent object, key); the key doubles as output name.
    ACCESS = [
        Field(name, [parent, name])
        for parent, name in (
            ("open_access", "is_oa"),
            ("open_access", "oa_status"),
            ("open_access", "oa_url"),
            ("primary_location", "pdf_url"),
            ("primary_location", "license"),
        )
    ]
def get_nested_value(data: Dict, path: List[str], default: Any = None) -> Any:
    """Follow ``path`` through nested containers and return what it points at.

    Each key in ``path`` is applied to the result of the previous lookup. As
    soon as one lookup raises ``KeyError`` (missing key) or ``TypeError``
    (current value cannot be indexed with that key, e.g. it is ``None``),
    ``default`` is returned. An empty path returns ``data`` unchanged.
    """
    current = data
    for key in path:
        try:
            current = current[key]
        except (KeyError, TypeError):
            break
    else:
        # Every step succeeded.
        return current
    return default
class DataAugmenter:
    """Augment paper records with features retrieved from the OpenAlex API.

    ``get_alex_features`` fetches a work by DOI and flattens it into sections
    (basic, source, authors, classification, metrics, access, abstract).
    ``filter_augmented_data`` then keeps only the features enabled in a
    ``ConfigAugmentation`` (``self.filters`` by default).
    """

    def __init__(self):
        """Load the user profile, set the default feature filters and configure pyalex.

        Raises:
            RuntimeError: If ``PROJECT_ROOT`` is not the ``MatchingPubs`` folder.
            FileNotFoundError: If ``user_information/profile.yaml`` is missing.
            KeyError: If the profile has no ``email`` entry.
        """
        self.profile = self._load_profile()
        self.email = self.profile["email"]
        # Default selection of features kept by ``filter_augmented_data``.
        self.filters = ConfigAugmentation(
            basic={
                "id": True,
                "doi": True,
                "title": True,
                "display_name": True,
                "publication_year": True,
                "publication_date": True,
                "language": True,
                "type": True,
                "type_crossref": True
            },
            source={
                "journal_name": True,
                "issn": True,
                "issn_l": True,
                "publisher": True,
                "type": True
            },
            authors={
                "position": True,
                "name": True,
                "id": True,
                "orcid": True,
                "is_corresponding": True,
                "affiliations": False
            },
            metrics={
                "cited_by_count": True,
                "cited_by_percentile": False,
                "is_retracted": True,
                "fwci": True,
                "referenced_works_count": True
            },
            classification={
                "primary_topic": True,
                "topics": False,
                "concepts": False,
            },
            access={
                "is_oa": True,
                "oa_status": True,
                "oa_url": True,
                "pdf_url": True,
                "license": True
            },
            related_works={
                "references": True,
                "referenced_by_count": True,
                "related": True
            },
            abstract=True
        )
        # Contact e-mail handed to pyalex's global configuration.
        pyalex.config.email = self.email

    def _load_profile(self) -> Dict[str, str]:
        """Read the contact e-mail from ``<PROJECT_ROOT>/user_information/profile.yaml``.

        Returns:
            ``{"email": <address>}``.

        Raises:
            RuntimeError: If ``PROJECT_ROOT`` is not named ``MatchingPubs``.
            FileNotFoundError: If the profile file does not exist.
            KeyError: If the file is empty or has no ``email`` entry.
        """
        root = str(PROJECT_ROOT)
        # Explicit raises instead of ``assert``: asserts are stripped under ``python -O``.
        # basename(normpath(...)) instead of split("/") so the check also works with
        # Windows separators and a trailing slash.
        if os.path.basename(os.path.normpath(root)) != "MatchingPubs":
            raise RuntimeError("Please run this script in the github repo folder ")
        profile_path = os.path.join(root, "user_information", "profile.yaml")
        if not os.path.exists(profile_path):
            raise FileNotFoundError("create a profile.yaml with your email (email:) and your api key (api_key:). Go here to get one https://dev.elsevier.com/")
        with open(profile_path, "r", encoding="utf-8") as f:
            profile = yaml.safe_load(f)
        # An empty YAML file loads as None; report it like a missing entry.
        if not isinstance(profile, dict) or "email" not in profile:
            raise KeyError(f"'email' entry missing from {profile_path}")
        return {
            "email": profile["email"]
        }

    @retry(
        stop=stop_after_attempt(5),  # Retry up to 5 times
        wait=wait_exponential(multiplier=1, min=1, max=60),  # Exponential backoff
        retry=retry_if_exception_type(RequestException)
    )
    def get_alex_features(self, doi: str) -> Dict[str, Any]:
        """Fetch the OpenAlex work for ``doi`` and extract all its features.

        Args:
            doi: DOI without the ``https://doi.org/`` prefix (added here).

        Returns:
            Dict with keys ``basic``, ``source``, ``authors``,
            ``classification``, ``metrics``, ``access`` and ``abstract``.
            ``authors``, ``classification`` and ``abstract`` are ``None`` when
            that part of the record cannot be parsed.

        Raises:
            Any error from the lookup is printed and re-raised.
            ``RequestException`` makes the ``tenacity`` decorator retry (up to
            5 attempts). Once attempts are exhausted, tenacity raises its own
            ``RetryError``, because ``reraise`` is not set.
        """
        try:
            work = Works()[f"https://doi.org/{doi}"]
            result: Dict[str, Any] = {
                "basic": self._extract_fields(work, AlexFields.BASIC),
                "source": self._extract_fields(work, AlexFields.SOURCE),
            }
            # Nested sections are best-effort: a malformed record yields None for
            # that section only, instead of aborting the whole lookup.
            try:
                result["authors"] = self._extract_authors(work)
            except (AttributeError, TypeError, KeyError):
                result["authors"] = None
            try:
                result["classification"] = self._extract_classification(work)
            except (AttributeError, TypeError, KeyError):
                result["classification"] = None
            result["metrics"] = self._extract_fields(work, AlexFields.METRICS)
            result["access"] = self._extract_fields(work, AlexFields.ACCESS)
            try:
                result["abstract"] = self._reconstruct_abstract(work.get("abstract_inverted_index"))
            except (AttributeError, TypeError, ValueError, IndexError):
                result["abstract"] = None
            return result
        except Exception:
            print(f"OpenAlex error for DOI {doi}")
            raise

    @staticmethod
    def _extract_fields(work: Dict[str, Any], fields: List["Field"]) -> Dict[str, Any]:
        """Map each field's name to the value at its path in ``work`` (its ``default`` if absent)."""
        return {field.name: get_nested_value(work, field.path, field.default) for field in fields}

    @staticmethod
    def _extract_authors(work: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten ``work["authorships"]`` into one dict per author, including institutions."""
        authors = []
        # ``or []`` / ``or {}`` also cover keys that are present but null.
        for auth in work.get("authorships") or []:
            author = auth.get("author") or {}
            authors.append({
                "position": auth.get("author_position"),
                "name": author.get("display_name"),
                "id": author.get("id"),
                "orcid": author.get("orcid"),
                "is_corresponding": auth.get("is_corresponding"),
                "affiliations": [
                    {
                        "name": inst.get("display_name"),
                        "id": inst.get("id"),
                        "country": inst.get("country_code"),
                        "type": inst.get("type"),
                        "ror": inst.get("ror"),
                    }
                    for inst in (auth.get("institutions") or [])
                ],
            })
        return authors

    @staticmethod
    def _extract_classification(work: Dict[str, Any]) -> Dict[str, Any]:
        """Collect the primary topic, all topics and all concepts of ``work``.

        A null ``primary_topic`` is treated like a missing one: every
        primary-topic entry is None, and topics and concepts are kept.
        """
        primary = work.get("primary_topic") or {}
        return {
            "primary_topic": {
                "name": primary.get("display_name"),
                "score": primary.get("score"),
                "field": get_nested_value(primary, ["field", "display_name"]),
                "subfield": get_nested_value(primary, ["subfield", "display_name"]),
            },
            "topics": [
                {
                    "name": topic.get("display_name"),
                    "score": topic.get("score"),
                    "field": get_nested_value(topic, ["field", "display_name"]),
                }
                for topic in (work.get("topics") or [])
            ],
            "concepts": [
                {
                    "name": concept.get("display_name"),
                    "level": concept.get("level"),
                    "score": concept.get("score"),
                    "wikidata": concept.get("wikidata"),
                }
                for concept in (work.get("concepts") or [])
            ],
        }

    @staticmethod
    def _reconstruct_abstract(inverted_index: Optional[Dict[str, List[int]]]) -> Optional[str]:
        """Rebuild abstract text from a word -> positions inverted index.

        Returns None for a missing or empty index. Positions that no word
        occupies become empty strings in the joined text. Raises ValueError if
        some word has an empty positions list; the caller maps that to None.
        """
        if not inverted_index:
            return None
        max_pos = max(max(positions) for positions in inverted_index.values())
        words = [""] * (max_pos + 1)
        for word, positions in inverted_index.items():
            for pos in positions:
                words[pos] = word
        return " ".join(words)

    def filter_augmented_data(self, data: Dict[str, Any], config: Optional[ConfigAugmentation] = None) -> Dict[str, Any]:
        """Filter data based on configuration.

        Args:
            data: Feature dict, e.g. as produced by ``get_alex_features``.
            config: Configuration specifying which features to include;
                ``self.filters`` when omitted.

        Returns:
            Dict containing only the configured features. A section appears
            only when its config is non-empty. Missing or malformed input
            sections come out as ``{}`` (or ``[]`` for authors).
        """
        config = config or self.filters

        def filter_section(section_data: Any, section_config: Dict[str, Any]) -> Dict[str, Any]:
            """Keep the entries of ``section_data`` whose key is enabled in ``section_config``."""
            if not isinstance(section_data, dict):
                return {}
            return {k: v for k, v in section_data.items() if section_config.get(k)}

        filtered: Dict[str, Any] = {}

        # Basic metadata
        if config.basic:
            filtered["basic"] = filter_section(data.get("basic"), config.basic)

        # Source/journal info
        if config.source:
            filtered["source"] = filter_section(data.get("source"), config.source)

        # Authors
        if config.authors:
            # "affiliations" may be a plain flag (keep affiliations as-is) or a
            # nested {field: bool} config applied to every affiliation.
            # Previously a True flag was passed to filter_section and crashed.
            affiliations_config = config.authors.get("affiliations", False)
            filtered_authors = []
            # ``or []``: get_alex_features stores None when author parsing fails.
            for author in data.get("authors") or []:
                filtered_author = filter_section(author, config.authors)
                if affiliations_config and isinstance(author, dict):
                    affiliations = author.get("affiliations") or []
                    if isinstance(affiliations_config, dict):
                        filtered_author["affiliations"] = [
                            filter_section(aff, affiliations_config) for aff in affiliations
                        ]
                    else:
                        filtered_author["affiliations"] = list(affiliations)
                filtered_authors.append(filtered_author)
            filtered["authors"] = filtered_authors

        # Metrics
        if config.metrics:
            filtered["metrics"] = filter_section(data.get("metrics"), config.metrics)

        # Classification
        if config.classification:
            filtered["classification"] = filter_section(data.get("classification"), config.classification)

        # Access info
        if config.access:
            filtered["access"] = filter_section(data.get("access"), config.access)

        # Related works
        # NOTE(review): get_alex_features never emits a "related_works" section, so
        # for its output this is always {} — confirm whether it should be populated.
        if config.related_works:
            filtered["related_works"] = filter_section(data.get("related_works"), config.related_works)

        # Abstract
        if config.abstract and "abstract" in data:
            filtered["abstract"] = data["abstract"]

        return filtered