| import os | |
| from tempfile import TemporaryDirectory | |
| from typing import Mapping, Optional, Sequence, Union | |
| from datasets import load_dataset as hf_load_dataset | |
| from tqdm import tqdm | |
| from .operator import SourceOperator | |
| from .stream import MultiStream | |
# The IBM COS SDK is an optional dependency: record whether it imported
# so that LoadFromIBMCloud.verify() can fail with an actionable message
# instead of this module raising ImportError at import time.
try:
    import ibm_boto3
    from ibm_botocore.client import ClientError
    ibm_boto3_available = True
except ImportError:
    ibm_boto3_available = False
class Loader(SourceOperator):
    """Base class for data loaders.

    A Loader is a SourceOperator that produces a MultiStream from an
    external data source; concrete subclasses implement ``process()``.
    """

    pass
class LoadHF(Loader):
    """Load a dataset via HuggingFace ``datasets.load_dataset``.

    The fields mirror the corresponding ``load_dataset`` arguments. The
    dataset is opened in streaming mode and its per-split iterables are
    wrapped into a MultiStream.
    """

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None
    cached = False

    def process(self):
        # Collect the pass-through arguments once, then stream the dataset.
        load_kwargs = dict(
            name=self.name,
            data_dir=self.data_dir,
            data_files=self.data_files,
            streaming=True,
        )
        splits = hf_load_dataset(self.path, **load_kwargs)
        return MultiStream.from_iterables(splits)
class LoadFromIBMCloud(Loader):
    """Load dataset files from an IBM Cloud Object Storage (COS) bucket.

    The ``*_env`` fields name environment variables that hold the COS
    endpoint URL and HMAC credentials; they are resolved in ``prepare()``
    and validated in ``verify()``. ``process()`` downloads each of
    ``data_files`` from ``data_dir`` in ``bucket_name`` into a temporary
    directory and loads that directory with HuggingFace ``load_dataset``.
    """

    # Names of the environment variables holding the connection details
    # (not the values themselves).
    endpoint_url_env: str
    aws_access_key_id_env: str
    aws_secret_access_key_env: str
    bucket_name: str
    data_dir: str
    data_files: Sequence[str]

    def _download_from_cos(self, cos, bucket_name, item_name, local_file):
        """Download one object from COS to ``local_file`` with a progress bar.

        Raises:
            Exception: if the object cannot be accessed or downloaded; the
                underlying SDK error is attached as the exception cause.
        """
        print(f"Downloading {item_name} from {bucket_name} COS to {local_file}")
        try:
            # HEAD-like fetch of object metadata to learn its size for tqdm.
            response = cos.Object(bucket_name, item_name).get()
            size = response["ContentLength"]
        except Exception as e:
            # Fix: "Unabled" typo; chain the cause so the SDK error is not lost.
            raise Exception(f"Unable to access {item_name} in {bucket_name} in COS") from e

        progress_bar = tqdm(total=size, unit="iB", unit_scale=True)

        def download_progress(chunk):
            # boto-style transfer callback: invoked with bytes transferred
            # per chunk. (Renamed from "upload_progress" — this is a download.)
            progress_bar.update(chunk)

        try:
            cos.Bucket(bucket_name).download_file(item_name, local_file, Callback=download_progress)
            print("\nDownload Successful")
        except Exception as e:
            raise Exception(f"Unable to download {item_name} in {bucket_name}") from e

    def prepare(self):
        super().prepare()
        # Resolve the connection settings from the environment; verify()
        # checks below that all three were actually set.
        self.endpoint_url = os.getenv(self.endpoint_url_env)
        self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
        self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)

    def verify(self):
        super().verify()
        assert (
            ibm_boto3_available
        ), "Please install ibm_boto3 in order to use the LoadFromIBMCloud loader (using `pip install ibm-cos-sdk`) "
        assert self.endpoint_url is not None, f"Please set the {self.endpoint_url_env} environmental variable"
        assert self.aws_access_key_id is not None, f"Please set {self.aws_access_key_id_env} environmental variable"
        assert (
            self.aws_secret_access_key is not None
        ), f"Please set {self.aws_secret_access_key_env} environmental variable"

    def process(self):
        """Download all data files and return them as a MultiStream."""
        cos = ibm_boto3.resource(
            "s3",
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            endpoint_url=self.endpoint_url,
        )
        with TemporaryDirectory() as temp_directory:
            for data_file in self.data_files:
                self._download_from_cos(
                    cos,
                    self.bucket_name,
                    # COS object keys always use "/", so "+" is correct here;
                    # the local path uses os.path.join for portability.
                    self.data_dir + "/" + data_file,
                    os.path.join(temp_directory, data_file),
                )
            # Must fully materialize (streaming=False) before the temporary
            # directory is deleted on exiting the context manager.
            dataset = hf_load_dataset(temp_directory, streaming=False)

        return MultiStream.from_iterables(dataset)