import os
import os.path as osp
from typing import Union
from stark_qa.skb import SKB, AmazonSKB, PrimeSKB, MagSKB

def load_skb(name: str,
             root: Union[str, None] = None,
             download_processed: bool = True,
             **kwargs) -> SKB:
    """
    Load the SKB dataset.

    Args:
        name (str): Name of the dataset. One of 'amazon', 'prime', or 'mag'.
        root (Union[str, None]): Root directory to store the dataset. If None, defaults to the HF cache path.
        download_processed (bool): Whether to download processed data. Default is True. If False, `root` must be provided.
        **kwargs: Additional keyword arguments for the specific dataset class.

    Returns:
        An instance of the specified SKB dataset class.

    Raises:
        ValueError: If the dataset name is not recognized.
        AssertionError: If `root` is not provided when `download_processed` is False.
    """
    if not download_processed:
        assert root is not None, "root must be provided if download_processed is False"

    if root is None:
        data_root = None
    else:
        root = os.path.abspath(root)
        data_root = osp.join(root, name)

    if name == 'amazon':
        categories = ['Sports_and_Outdoors']
        skb = AmazonSKB(root=data_root,
                        categories=categories,
                        download_processed=download_processed,
                        **kwargs)
    elif name == 'prime':
        skb = PrimeSKB(root=data_root,
                       download_processed=download_processed,
                       **kwargs)
    elif name == 'mag':
        skb = MagSKB(root=data_root,
                     download_processed=download_processed,
                     **kwargs)
    else:
        raise ValueError(f"Unknown dataset {name}")
    return skb
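

# Minimal usage sketch (illustration only, not part of the library): loads the Amazon SKB
# with processed data downloaded into a local directory. The './data' path is a hypothetical
# example; any writable directory works, and omitting `root` falls back to the HF cache path.
if __name__ == "__main__":
    skb = load_skb('amazon', root='./data', download_processed=True)
    print(f"Loaded {type(skb).__name__} into ./data/amazon")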