|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import abc |
|
import inspect |
|
import json |
|
from typing import Any, Dict, List, Literal, Optional, Union |
|
|
|
from erniebot_agent.agents.callback.callback_manager import CallbackManager |
|
from erniebot_agent.agents.callback.default import get_default_callbacks |
|
from erniebot_agent.agents.callback.handlers.base import CallbackHandler |
|
from erniebot_agent.agents.schema import ( |
|
AgentFile, |
|
AgentResponse, |
|
LLMResponse, |
|
ToolResponse, |
|
) |
|
from erniebot_agent.chat_models.base import ChatModel |
|
from erniebot_agent.file_io.file_manager import FileManager |
|
from erniebot_agent.file_io.protocol import is_local_file_id, is_remote_file_id |
|
from erniebot_agent.memory.base import Memory |
|
from erniebot_agent.messages import Message, SystemMessage |
|
from erniebot_agent.tools.base import Tool |
|
from erniebot_agent.tools.tool_manager import ToolManager |
|
from erniebot_agent.utils.logging import logger |
|
|
|
|
|
class BaseAgent(metaclass=abc.ABCMeta):
    """Abstract interface shared by every agent.

    Concrete agents expose a chat model as ``llm`` and a conversation store as
    ``memory``, and must implement :meth:`async_run`.
    """

    # Chat model the agent talks to.
    llm: ChatModel
    # Conversation store used by the agent.
    memory: Memory

    @abc.abstractmethod
    async def async_run(self, prompt: str) -> AgentResponse:
        """Run the agent on ``prompt`` and return its response."""
        raise NotImplementedError()
|
|
|
|
|
class Agent(BaseAgent):
    """Base class for agents built from an LLM, tools, memory and callbacks.

    Subclasses implement :meth:`_async_run` with the actual agent logic. They use
    :meth:`_async_run_llm` and :meth:`_async_run_tool` so that the configured
    callbacks fire around every LLM call and every tool call.
    """

    def __init__(
        self,
        llm: ChatModel,
        tools: Union[ToolManager, List[Tool]],
        memory: Memory,
        system_message: Optional[SystemMessage] = None,
        *,
        callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
        file_manager: Optional[FileManager] = None,
    ) -> None:
        """Initialize the agent.

        Args:
            llm: Chat model used for LLM calls.
            tools: A ready-made ``ToolManager``, or a list of tools to wrap in a
                new ``ToolManager``.
            memory: Memory holding the conversation.
            system_message: System message for the agent. If ``None``, the
                memory's own system message (``memory.get_system_message()``)
                is used instead.
            callbacks: A ``CallbackManager``, or a list of handlers to wrap in a
                new one. If ``None``, ``get_default_callbacks()`` is used.
            file_manager: Optional file manager used to resolve file IDs that
                appear in tool inputs and outputs.
        """
        super().__init__()
        self.llm = llm
        self.memory = memory

        # Compare against None explicitly: a caller-supplied message must not be
        # replaced just because the object happens to evaluate as falsy.
        if system_message is not None:
            self.system_message = system_message
        else:
            self.system_message = memory.get_system_message()

        if isinstance(tools, ToolManager):
            self._tool_manager = tools
        else:
            self._tool_manager = ToolManager(tools)

        if callbacks is None:
            callbacks = get_default_callbacks()
        # The defaults may themselves be a manager or a plain handler list.
        if isinstance(callbacks, CallbackManager):
            self._callback_manager = callbacks
        else:
            self._callback_manager = CallbackManager(callbacks)

        self.file_manager = file_manager

    async def async_run(self, prompt: str) -> AgentResponse:
        """Run the agent on ``prompt`` and fire the run-level callbacks.

        ``on_run_start`` fires before :meth:`_async_run` and ``on_run_end`` fires
        after it. If ``_async_run`` raises, the exception propagates and
        ``on_run_end`` is not fired.

        Args:
            prompt: User input for this run.

        Returns:
            The response produced by :meth:`_async_run`.
        """
        await self._callback_manager.on_run_start(agent=self, prompt=prompt)
        agent_resp = await self._async_run(prompt)
        await self._callback_manager.on_run_end(agent=self, response=agent_resp)
        return agent_resp

    def load_tool(self, tool: Tool) -> None:
        """Attach ``tool`` to the agent's tool manager."""
        self._tool_manager.add_tool(tool)

    def unload_tool(self, tool: Tool) -> None:
        """Detach ``tool`` from the agent's tool manager."""
        self._tool_manager.remove_tool(tool)

    def reset_memory(self) -> None:
        """Clear the chat history stored in the agent's memory."""
        self.memory.clear_chat_history()

    def launch_gradio_demo(self, **launch_kwargs: Any) -> None:
        """Launch an interactive Gradio chat UI backed by this agent.

        Requires the optional ``gradio`` dependency. Each submitted prompt is run
        through :meth:`async_run`. The UI also shows:

        - the function-call schemas of the attached tools;
        - all raw messages exchanged in this UI session;
        - the messages currently in memory.

        "Clear" resets the displayed state and the agent's memory.

        Args:
            **launch_kwargs: Forwarded unchanged to ``demo.launch()``.

        Raises:
            ImportError: If ``gradio`` is not installed.
        """
        try:
            import gradio as gr
        except ImportError:
            raise ImportError(
                "Could not import gradio, which is required for `launch_gradio_demo()`."
                " Please run `pip install erniebot-agent[gradio]` to install the optional dependencies."
            ) from None

        # Every message produced by the runs in this UI session; emptied by "Clear".
        raw_messages: List[Message] = []

        def _messages_to_dicts(messages):
            return [message.to_dict() for message in messages]

        def _pre_chat(text, history):
            # Show the prompt right away and lock the inputs while the agent runs.
            history.append([text, None])
            return history, gr.update(value="", interactive=False), gr.update(interactive=False)

        async def _chat(history):
            prompt = history[-1][0]
            if not prompt:
                raise gr.Error("Prompt should not be empty.")
            response = await self.async_run(prompt)
            history[-1][1] = response.text
            raw_messages.extend(response.chat_history)
            return (
                history,
                _messages_to_dicts(raw_messages),
                _messages_to_dicts(self.memory.get_messages()),
            )

        def _post_chat():
            # Chained with `.then()` (not `.success()`), so the inputs are
            # re-enabled even when `_chat` raised.
            return gr.update(interactive=True), gr.update(interactive=True)

        def _clear():
            raw_messages.clear()
            self.reset_memory()
            # One value per output: chatbot, prompt box, and the two JSON views.
            return None, None, None, None

        with gr.Blocks(
            title="ERNIE Bot Agent Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="md")
        ) as demo:
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Chat history",
                    latex_delimiters=[
                        {"left": "$$", "right": "$$", "display": True},
                        {"left": "$", "right": "$", "display": False},
                    ],
                    bubble_full_width=False,
                )
                prompt_textbox = gr.Textbox(label="Prompt", placeholder="Write a prompt here...")
                with gr.Row():
                    submit_button = gr.Button("Submit")
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Tools", open=False):
                    attached_tools = self._tool_manager.get_tools()
                    tool_descriptions = [tool.function_call_schema() for tool in attached_tools]
                    gr.JSON(value=tool_descriptions)
                with gr.Accordion("Raw messages", open=False):
                    all_messages_json = gr.JSON(label="All messages")
                    agent_memory_json = gr.JSON(label="Messages in memory")

            # Pressing Enter in the textbox and clicking "Submit" run the same chain.
            for trigger in (prompt_textbox.submit, submit_button.click):
                trigger(
                    _pre_chat,
                    inputs=[prompt_textbox, chatbot],
                    outputs=[chatbot, prompt_textbox, submit_button],
                ).then(
                    _chat,
                    inputs=[chatbot],
                    outputs=[chatbot, all_messages_json, agent_memory_json],
                ).then(_post_chat, outputs=[prompt_textbox, submit_button])

            clear_button.click(
                _clear,
                outputs=[chatbot, prompt_textbox, all_messages_json, agent_memory_json],
            )

        demo.launch(**launch_kwargs)

    @abc.abstractmethod
    async def _async_run(self, prompt: str) -> AgentResponse:
        """Agent-specific run logic, wrapped by :meth:`async_run`; implemented by subclasses."""
        raise NotImplementedError

    async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
        """Look up and run a tool, firing the tool-level callbacks.

        Args:
            tool_name: Name of the tool to look up in the tool manager.
            tool_args: JSON-encoded object holding the tool's arguments.

        Returns:
            The tool's response.

        Raises:
            Any exception raised while running the tool. ``on_tool_error`` fires
            first and the exception is then re-raised unchanged. A failure in the
            tool lookup propagates before any callback fires.
        """
        tool = self._tool_manager.get_tool(tool_name)
        await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args)
        try:
            tool_resp = await self._async_run_tool_without_hooks(tool, tool_args)
        except (Exception, KeyboardInterrupt) as e:
            # KeyboardInterrupt is included so that interrupted runs still reach the callbacks.
            await self._callback_manager.on_tool_error(agent=self, tool=tool, error=e)
            raise
        await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp)
        return tool_resp

    async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
        """Call the LLM on ``messages``, firing the LLM-level callbacks.

        Args:
            messages: Conversation to send to the model.
            **opts: Forwarded to :meth:`_async_run_llm_without_hooks`
                (e.g. ``functions``).

        Returns:
            The model's response.

        Raises:
            Any exception raised by the LLM call. ``on_llm_error`` fires first and
            the exception is then re-raised unchanged.
        """
        await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
        try:
            llm_resp = await self._async_run_llm_without_hooks(messages, **opts)
        except (Exception, KeyboardInterrupt) as e:
            await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
            raise
        await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
        return llm_resp

    async def _async_run_tool_without_hooks(self, tool: Tool, tool_args: str) -> ToolResponse:
        """Run ``tool`` without callbacks and collect the files it uses.

        Args:
            tool: Tool to call.
            tool_args: JSON-encoded object holding the tool's arguments.

        Returns:
            A ``ToolResponse`` containing:

            - the tool result serialized as JSON (non-ASCII characters kept);
            - the files found among the tool's inputs, followed by those found
              among its outputs.
        """
        bnd_args = self._parse_tool_args(tool, tool_args)
        input_files = await self._sniff_and_extract_files_from_args(bnd_args.arguments, tool, "input")
        tool_ret = await tool(*bnd_args.args, **bnd_args.kwargs)
        # NOTE(review): the output sniffing below calls `.values()`, so this assumes
        # tools return a dict — confirm against the Tool contract.
        output_files = await self._sniff_and_extract_files_from_args(tool_ret, tool, "output")
        tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
        return ToolResponse(json=tool_ret_json, files=input_files + output_files)

    async def _async_run_llm_without_hooks(
        self, messages: List[Message], functions=None, **opts: Any
    ) -> LLMResponse:
        """Call ``self.llm.async_chat`` (non-streaming) and wrap the returned message.

        Args:
            messages: Conversation to send to the model.
            functions: Passed through to ``async_chat`` as ``functions``.
            **opts: Extra keyword options passed through to ``async_chat``.

        Returns:
            An ``LLMResponse`` wrapping the returned message.
        """
        llm_ret = await self.llm.async_chat(messages, functions=functions, stream=False, **opts)
        return LLMResponse(message=llm_ret)

    def _parse_tool_args(self, tool: Tool, tool_args: str) -> inspect.BoundArguments:
        """Decode ``tool_args`` and bind it to the signature of ``tool.__call__``.

        Defaults are applied, so the result lists every parameter.

        Args:
            tool: Tool whose call signature the arguments are bound to.
            tool_args: JSON-encoded object holding the tool's arguments.

        Returns:
            The bound arguments.

        Raises:
            ValueError: If ``tool_args`` is not valid JSON (``json.JSONDecodeError``
                is a subclass of ``ValueError``), or does not decode to a dict.
            TypeError: If the arguments do not match the tool's call signature.
        """
        args_dict = json.loads(tool_args)
        if not isinstance(args_dict, dict):
            raise ValueError("`tool_args` cannot be interpreted as a dict.")

        sig = inspect.signature(tool.__call__)
        bnd_args = sig.bind(**args_dict)
        bnd_args.apply_defaults()
        return bnd_args

    async def _sniff_and_extract_files_from_args(
        self, args: Dict[str, Any], tool: Tool, file_type: Literal["input", "output"]
    ) -> List[AgentFile]:
        """Collect the files referenced by file IDs among the values of ``args``.

        Only top-level string values are inspected. Handling per value:

        - Local file IDs must already be registered with the file manager.
        - Remote file IDs that the file manager does not yet know are retrieved
          through it.
        - If the agent has no file manager, file IDs are skipped with a warning.

        Args:
            args: Mapping whose values are scanned (tool arguments or tool result).
            tool: Tool that consumes or produced the values.
            file_type: Whether the files are tool ``"input"`` or ``"output"``.

        Returns:
            One ``AgentFile`` per recognized file ID, in iteration order of ``args``.

        Raises:
            RuntimeError: If a local file ID is not registered with the file manager.
        """
        agent_files: List[AgentFile] = []
        for val in args.values():
            if not isinstance(val, str):
                continue
            is_local = is_local_file_id(val)
            if not is_local and not is_remote_file_id(val):
                continue
            # Both kinds of file ID need a file manager to be resolved.
            if self.file_manager is None:
                logger.warning(
                    f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
                )
                continue
            file = self.file_manager.look_up_file_by_id(val)
            if file is None:
                if is_local:
                    # Unlike remote IDs, an unknown local ID is not retrieved on demand.
                    raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.")
                file = await self.file_manager.retrieve_remote_file_by_id(val)
            agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
        return agent_files
|
|