|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import pathlib |
|
import uuid |
|
from typing import Literal, Optional, Union, overload |
|
|
|
import anyio |
|
from erniebot_agent.file_io.base import File |
|
from erniebot_agent.file_io.file_registry import FileRegistry, get_file_registry |
|
from erniebot_agent.file_io.local_file import LocalFile, create_local_file_from_path |
|
from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient |
|
from erniebot_agent.utils.temp_file import create_tracked_temp_dir |
|
from typing_extensions import TypeAlias |
|
|
|
# Path-like inputs accepted by FileManager: a plain string or any os.PathLike
# object (e.g. pathlib.Path). FileManager normalizes these with pathlib.Path(...).
_PathType: TypeAlias = Union[str, os.PathLike]
|
|
|
|
|
class FileManager:
    """Create, retrieve, and look up files backed by local storage or a remote service.

    Byte contents are materialized as files under a save directory. Remote
    operations go through an optional ``RemoteFileClient``. When
    ``auto_register`` is true, every file created or retrieved through this
    manager (local or remote) is added to the file registry.
    """

    _remote_file_client: Optional[RemoteFileClient]

    def __init__(
        self,
        remote_file_client: Optional[RemoteFileClient] = None,
        *,
        auto_register: bool = True,
        save_dir: Optional[_PathType] = None,
    ) -> None:
        """Initialize the manager.

        Args:
            remote_file_client: Client used for remote uploads and retrievals.
                If ``None``, any remote operation raises ``AttributeError``.
            auto_register: Whether files created or retrieved by this manager
                are registered in the file registry automatically.
            save_dir: Directory in which ``create_file_from_bytes`` writes files.
                It is not created here, so it is expected to exist already. If
                ``None``, a directory from ``create_tracked_temp_dir()`` is used.
        """
        super().__init__()
        self._remote_file_client = remote_file_client
        self._auto_register = auto_register
        if save_dir is not None:
            self._save_dir = pathlib.Path(save_dir)
        else:
            # NOTE(review): presumably the tracked temp dir is cleaned up by the
            # temp_file utility (e.g. at exit) — confirm in erniebot_agent.utils.temp_file.
            self._save_dir = create_tracked_temp_dir()
        self._file_registry = get_file_registry()

    @property
    def registry(self) -> FileRegistry:
        """The file registry obtained from ``get_file_registry()`` at construction."""
        return self._file_registry

    @property
    def remote_file_client(self) -> RemoteFileClient:
        """The remote file client.

        Raises:
            AttributeError: If the manager was constructed without a client.
        """
        if self._remote_file_client is None:
            raise AttributeError("No remote file client is set.")
        return self._remote_file_client

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_path(
        self, file_path: _PathType, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Create a file object for an existing path on the local filesystem.

        Args:
            file_path: Path of the source file.
            file_type: ``"local"`` to create a ``LocalFile``; ``"remote"`` to
                upload the file through the remote file client.

        Returns:
            The created ``LocalFile`` or ``RemoteFile``.

        Raises:
            ValueError: If ``file_type`` is neither ``"local"`` nor ``"remote"``.
            AttributeError: If ``file_type`` is ``"remote"`` and no remote file
                client is set.
        """
        if file_type == "local":
            return await self.create_local_file_from_path(file_path)
        if file_type == "remote":
            return await self.create_remote_file_from_path(file_path)
        raise ValueError(f"Unsupported file type: {file_type}")

    async def create_local_file_from_path(self, file_path: _PathType) -> LocalFile:
        """Create a ``LocalFile`` for ``file_path``.

        The file is registered only when ``auto_register`` is enabled, which
        matches the remote creation and retrieval methods.
        """
        # Resolves to the module-level factory function, not this method.
        file = create_local_file_from_path(pathlib.Path(file_path))
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    async def create_remote_file_from_path(self, file_path: _PathType) -> RemoteFile:
        """Upload ``file_path`` through the remote file client and return the result.

        The returned file is registered only when ``auto_register`` is enabled.

        Raises:
            AttributeError: If no remote file client is set.
        """
        file = await self.remote_file_client.upload_file(pathlib.Path(file_path))
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local"] = ...
    ) -> LocalFile:
        ...

    @overload
    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["remote"]
    ) -> RemoteFile:
        ...

    async def create_file_from_bytes(
        self, file_contents: bytes, filename: str, *, file_type: Literal["local", "remote"] = "local"
    ) -> Union[LocalFile, RemoteFile]:
        """Write ``file_contents`` to a new file in the save directory and wrap it.

        The on-disk name is ``<stem><uuid4><suffix>``, where stem and suffix come
        from the final component of ``filename``. Directory parts of
        ``filename`` are therefore ignored and cannot place the file outside the
        save directory. The local copy is kept on disk even when
        ``file_type == "remote"``.

        Args:
            file_contents: Raw bytes to write.
            filename: Name whose stem and suffix seed the generated file name.
            file_type: Passed through to ``create_file_from_path``.

        Returns:
            The created ``LocalFile`` or ``RemoteFile``.
        """
        name = pathlib.PurePath(filename)
        file_path = self._fs_create_file(prefix=name.stem, suffix=name.suffix)
        async with await anyio.open_file(file_path, "wb") as f:
            await f.write(file_contents)
        return await self.create_file_from_path(file_path, file_type=file_type)

    async def retrieve_remote_file_by_id(self, file_id: str) -> RemoteFile:
        """Fetch a remote file by ID through the remote file client.

        The returned file is registered only when ``auto_register`` is enabled.

        Raises:
            AttributeError: If no remote file client is set.
        """
        file = await self.remote_file_client.retrieve_file(file_id)
        if self._auto_register:
            self._file_registry.register_file(file)
        return file

    def look_up_file_by_id(self, file_id: str) -> Optional[File]:
        """Return the registered file with ``file_id``, or ``None`` if the registry has none."""
        return self._file_registry.look_up_file(file_id)

    def _fs_create_file(self, prefix: Optional[str] = None, suffix: Optional[str] = None) -> pathlib.Path:
        """Create an empty, uniquely named file in the save directory and return its path."""
        filename = f"{prefix or ''}{uuid.uuid4()}{suffix or ''}"
        file_path = self._save_dir / filename
        # exist_ok=False: never silently reuse (and later overwrite) a pre-existing file.
        file_path.touch(exist_ok=False)
        return file_path
|
|