import os
from typing import Union

from stark_qa.retrieval import STaRKDataset

# Dataset names that load_qa accepts.
REGISTERED_DATASETS = ['amazon', 'prime', 'mag']


def load_qa(name: str,
            root: Union[str, None] = None,
            human_generated_eval: bool = False) -> STaRKDataset:
    """
    Build the STaRK question-answering dataset identified by ``name``.

    Args:
        name (str): Dataset identifier; must be an entry of
            REGISTERED_DATASETS ('amazon', 'prime', or 'mag').
        root (Union[str, None]): Directory in which the dataset is stored.
            A relative path is resolved against the current working
            directory; an absolute path is forwarded untouched. When None,
            it is passed through as-is and the default Hugging Face cache
            path is used.
        human_generated_eval (bool): Forwarded to STaRKDataset to select the
            human-generated evaluation data. Defaults to False.

    Returns:
        STaRKDataset: The dataset object constructed for ``name``.

    Raises:
        ValueError: If ``name`` is not one of the registered datasets.
    """
    # Only relative paths are rewritten; None and absolute paths pass through.
    if root is not None and not os.path.isabs(root):
        root = os.path.abspath(root)

    # Reject unknown names before attempting to construct the dataset.
    if name not in REGISTERED_DATASETS:
        raise ValueError(f"Unknown dataset {name}")

    return STaRKDataset(name, root, human_generated_eval=human_generated_eval)