|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import abc |
|
import asyncio |
|
import functools |
|
import pathlib |
|
import time |
|
import uuid |
|
from typing import ClassVar, Dict, List |
|
|
|
import anyio |
|
from baidubce.auth.bce_credentials import BceCredentials |
|
from baidubce.bce_client_configuration import BceClientConfiguration |
|
from baidubce.services.bos.bos_client import BosClient |
|
from erniebot_agent.file_io.base import File |
|
from erniebot_agent.file_io.protocol import ( |
|
build_remote_file_id_from_uuid, |
|
is_remote_file_id, |
|
) |
|
|
|
|
|
class RemoteFile(File):
    """A file whose contents are accessed through a ``RemoteFileClient``.

    The object itself only carries the values passed to the constructor;
    reading and deleting are delegated to the client it was created with.
    """

    def __init__(self, id: str, filename: str, created_at: int, client: "RemoteFileClient") -> None:
        """Initialize the remote file handle.

        Args:
            id: Identifier of the file. Must be accepted by ``is_remote_file_id``.
            filename: Name of the file, forwarded to the base class.
            created_at: Creation timestamp, forwarded to the base class.
            client: Client used to read the contents of and delete this file.

        Raises:
            ValueError: If ``id`` is rejected by ``is_remote_file_id``.
        """
        if not is_remote_file_id(id):
            # The message must be an f-string; otherwise the literal text
            # "{id}" is reported instead of the offending value.
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self._client = client

    async def read_contents(self) -> bytes:
        """Return the file contents as fetched by the client for ``self.id``."""
        # NOTE(review): `self.id` is presumably set by `File.__init__` -- confirm.
        file_contents = await self._client.retrieve_file_contents(self.id)
        return file_contents

    async def delete(self) -> None:
        """Ask the client to delete the file identified by ``self.id``."""
        await self._client.delete_file(self.id)
|
|
|
|
|
class RemoteFileClient(metaclass=abc.ABCMeta):
    """Abstract interface of a client that manages ``RemoteFile`` objects.

    Every operation is an abstract coroutine; concrete subclasses must
    override all of them before they can be instantiated.
    """

    @abc.abstractmethod
    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the file at ``file_path`` and return its ``RemoteFile``."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Return the ``RemoteFile`` identified by ``file_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Return the raw contents of the file identified by ``file_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    async def list_files(self) -> List[RemoteFile]:
        """Return the ``RemoteFile`` objects known to this client."""
        raise NotImplementedError

    @abc.abstractmethod
    async def delete_file(self, file_id: str) -> None:
        """Delete the file identified by ``file_id``."""
        raise NotImplementedError
|
|
|
|
|
class BOSFileClient(RemoteFileClient):
    """``RemoteFileClient`` implementation built on the Baidu BCE ``BosClient``.

    Objects are stored in ``bucket_name`` under the key ``prefix + file_id``.
    The SDK calls are synchronous, so each one is dispatched to the running
    event loop's default executor.
    """

    # Endpoint handed to ``BceClientConfiguration`` for every client instance.
    _ENDPOINT: ClassVar[str] = "bj.bcebos.com"

    def __init__(self, ak: str, sk: str, bucket_name: str, prefix: str) -> None:
        """Create the underlying ``BosClient``.

        Args:
            ak: First argument passed to ``BceCredentials``.
            sk: Second argument passed to ``BceCredentials``.
            bucket_name: Bucket used for every object operation.
            prefix: String prepended to a file ID to form the object key.
        """
        super().__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix
        credentials = BceCredentials(ak, sk)
        client_config = BceClientConfiguration(credentials=credentials, endpoint=self._ENDPOINT)
        self._bos_client = BosClient(config=client_config)

    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the file at ``file_path`` and return a handle to it.

        The whole file is read into memory and passed to
        ``BosClient.put_object_from_string`` together with user metadata
        holding the generated ID, the file name and the creation time.

        Args:
            file_path: Path of the local file to upload.

        Returns:
            A ``RemoteFile`` bound to this client.
        """
        file_id = self._generate_file_id()
        filename = file_path.name
        created_at = int(time.time())
        # Metadata values are strings, so the timestamp is stringified here
        # and converted back with ``int`` in ``retrieve_file``.
        user_metadata: Dict[str, str] = {
            "id": file_id,
            "filename": filename,
            "created_at": str(created_at),
        }

        async with await anyio.open_file(file_path, mode="rb") as stream:
            data = await stream.read()

        loop = asyncio.get_running_loop()
        put_object = functools.partial(
            self._bos_client.put_object_from_string,
            bucket=self.bucket_name,
            key=self._get_key(file_id),
            data=data,
            user_metadata=user_metadata,
        )
        await loop.run_in_executor(None, put_object)

        return RemoteFile(id=file_id, filename=filename, created_at=created_at, client=self)

    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Build a ``RemoteFile`` from the metadata stored with the object.

        Args:
            file_id: ID of the file to look up.

        Returns:
            A ``RemoteFile`` bound to this client.

        Raises:
            RuntimeError: If the ID stored in the metadata differs from ``file_id``.
        """
        loop = asyncio.get_running_loop()
        fetch_metadata = functools.partial(
            self._bos_client.get_object_meta_data,
            self.bucket_name,
            self._get_key(file_id),
        )
        response = await loop.run_in_executor(None, fetch_metadata)

        # NOTE(review): the keys written by `upload_file` are read back as
        # `bce_meta_<key>` attributes -- presumably an SDK convention; confirm.
        metadata = response.metadata
        stored_id = metadata.bce_meta_id
        stored_filename = metadata.bce_meta_filename
        stored_created_at = int(metadata.bce_meta_created_at)

        if file_id != stored_id:
            raise RuntimeError("`file_id` is not the same as the one in metadata.")

        return RemoteFile(
            id=stored_id,
            filename=stored_filename,
            created_at=stored_created_at,
            client=self,
        )

    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Return the result of ``BosClient.get_object_as_string`` for ``file_id``."""
        loop = asyncio.get_running_loop()
        fetch_contents = functools.partial(
            self._bos_client.get_object_as_string,
            self.bucket_name,
            self._get_key(file_id),
        )
        return await loop.run_in_executor(None, fetch_contents)

    async def list_files(self) -> List[RemoteFile]:
        """Always raise: listing is not supported by this client.

        Raises:
            RuntimeError: Unconditionally.
        """
        raise RuntimeError(f"`{self.__class__.__name__}.list_files` is not supported.")

    async def delete_file(self, file_id: str) -> None:
        """Delete the object stored for ``file_id`` via ``BosClient.delete_object``."""
        loop = asyncio.get_running_loop()
        remove_object = functools.partial(
            self._bos_client.delete_object,
            self.bucket_name,
            self._get_key(file_id),
        )
        await loop.run_in_executor(None, remove_object)

    def _get_key(self, file_id: str) -> str:
        """Return the object key for ``file_id``: the prefix followed by the ID."""
        return self.prefix + file_id

    @staticmethod
    def _generate_file_id() -> str:
        """Return a new file ID built from a freshly generated UUID1 string."""
        raw_uuid = str(uuid.uuid1())
        return build_remote_file_id_from_uuid(raw_uuid)
|
|