markqiu's picture
百度文心一言的例子
569cdb0
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
import functools
import pathlib
import time
import uuid
from typing import ClassVar, Dict, List
import anyio
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from erniebot_agent.file_io.base import File
from erniebot_agent.file_io.protocol import (
build_remote_file_id_from_uuid,
is_remote_file_id,
)
class RemoteFile(File):
    """A `File` whose contents live behind a `RemoteFileClient`.

    The object itself only carries the metadata passed to the constructor;
    reading and deleting are delegated to the client it was created with.
    """

    def __init__(self, id: str, filename: str, created_at: int, client: "RemoteFileClient") -> None:
        """Create a handle for a remote file.

        Args:
            id: Identifier of the file. Must be accepted by `is_remote_file_id`.
            filename: Name of the file, forwarded to the base class.
            created_at: Creation timestamp as an integer, forwarded to the base class.
            client: Client used to retrieve the contents of, and to delete, this file.

        Raises:
            ValueError: If `id` is rejected by `is_remote_file_id`.
        """
        if not is_remote_file_id(id):
            # f-string so that the rejected ID is actually interpolated into
            # the message instead of the literal text "{id}".
            raise ValueError(f"Invalid file ID: {id}")
        super().__init__(id=id, filename=filename, created_at=created_at)
        self._client = client

    async def read_contents(self) -> bytes:
        """Return the file contents as fetched by the client for this file's ID."""
        return await self._client.retrieve_file_contents(self.id)

    async def delete(self) -> None:
        """Ask the client to delete the file identified by this object's ID."""
        await self._client.delete_file(self.id)
class RemoteFileClient(metaclass=abc.ABCMeta):
    """Abstract interface of a client that manages files on a remote store.

    Every operation is an abstract coroutine; a concrete subclass has to
    override all of them before it can be instantiated.
    """

    @abc.abstractmethod
    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the local file at `file_path` and return a `RemoteFile` for it."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Return the `RemoteFile` identified by `file_id`."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Return the raw bytes of the file identified by `file_id`."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def list_files(self) -> List[RemoteFile]:
        """Return a list of `RemoteFile` objects known to this client."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def delete_file(self, file_id: str) -> None:
        """Delete the file identified by `file_id`."""
        raise NotImplementedError()
class BOSFileClient(RemoteFileClient):
    """`RemoteFileClient` implementation that stores files through a `BosClient`.

    Each file is stored in `bucket_name` under the key `prefix + file_id`.
    The `BosClient` calls are run via `loop.run_in_executor(None, ...)`
    so that they do not execute on the event loop itself.
    """

    _ENDPOINT: ClassVar[str] = "bj.bcebos.com"

    def __init__(self, ak: str, sk: str, bucket_name: str, prefix: str) -> None:
        """Build the underlying `BosClient`.

        Args:
            ak: First argument passed to `BceCredentials` (presumably the access key).
            sk: Second argument passed to `BceCredentials` (presumably the secret key).
            bucket_name: Bucket used for every object operation.
            prefix: String prepended to a file ID to form the object key.
        """
        super().__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix
        credentials = BceCredentials(ak, sk)
        self._bos_client = BosClient(
            config=BceClientConfiguration(credentials=credentials, endpoint=self._ENDPOINT)
        )

    async def upload_file(self, file_path: pathlib.Path) -> RemoteFile:
        """Upload the file at `file_path` and return a `RemoteFile` describing it.

        A fresh file ID is generated, and the ID, the file name and the current
        time (whole seconds) are attached to the object as user metadata.
        """
        new_id = self._generate_file_id()
        name = file_path.name
        timestamp = int(time.time())
        metadata: Dict[str, str] = {"id": new_id, "filename": name, "created_at": str(timestamp)}

        # The whole file is read into memory before it is handed to the client.
        async with await anyio.open_file(file_path, mode="rb") as stream:
            payload = await stream.read()

        put_object = functools.partial(
            self._bos_client.put_object_from_string,
            bucket=self.bucket_name,
            key=self._get_key(new_id),
            data=payload,
            user_metadata=metadata,
        )
        await asyncio.get_running_loop().run_in_executor(None, put_object)

        return RemoteFile(id=new_id, filename=name, created_at=timestamp, client=self)

    async def retrieve_file(self, file_id: str) -> RemoteFile:
        """Rebuild a `RemoteFile` from the metadata stored with the object.

        Raises:
            RuntimeError: If the ID recorded in the metadata differs from `file_id`.
        """
        fetch_meta = functools.partial(
            self._bos_client.get_object_meta_data, self.bucket_name, self._get_key(file_id)
        )
        response = await asyncio.get_running_loop().run_in_executor(None, fetch_meta)

        # All three fields are read (and the timestamp converted) before the
        # ID check, so a malformed timestamp surfaces before an ID mismatch.
        meta = response.metadata
        stored_id = meta.bce_meta_id
        stored_name = meta.bce_meta_filename
        stored_created_at = int(meta.bce_meta_created_at)

        if file_id != stored_id:
            raise RuntimeError("`file_id` is not the same as the one in metadata.")

        return RemoteFile(
            id=stored_id,
            filename=stored_name,
            created_at=stored_created_at,
            client=self,
        )

    async def retrieve_file_contents(self, file_id: str) -> bytes:
        """Return what `BosClient.get_object_as_string` yields for the object of `file_id`."""
        fetch_contents = functools.partial(
            self._bos_client.get_object_as_string, self.bucket_name, self._get_key(file_id)
        )
        return await asyncio.get_running_loop().run_in_executor(None, fetch_contents)

    async def list_files(self) -> List[RemoteFile]:
        """Unsupported for this client.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError(f"`{self.__class__.__name__}.list_files` is not supported.")

    async def delete_file(self, file_id: str) -> None:
        """Delete the object stored for `file_id`."""
        remove_object = functools.partial(
            self._bos_client.delete_object, self.bucket_name, self._get_key(file_id)
        )
        await asyncio.get_running_loop().run_in_executor(None, remove_object)

    def _get_key(self, file_id: str) -> str:
        """Return the object key for `file_id`: the configured prefix followed by the ID."""
        key = self.prefix + file_id
        return key

    @staticmethod
    def _generate_file_id() -> str:
        """Return a new file ID built from a freshly generated UUID1 string."""
        raw_uuid = str(uuid.uuid1())
        return build_remote_file_id_from_uuid(raw_uuid)