# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

from erniebot_agent.agents.base import Agent, ToolManager
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.schema import AgentAction, AgentFile, AgentResponse
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.file_io.file_manager import FileManager
from erniebot_agent.memory.base import Memory
from erniebot_agent.messages import (
FunctionMessage,
HumanMessage,
Message,
SystemMessage,
)
from erniebot_agent.tools.base import Tool

_MAX_STEPS = 5


class FunctionalAgent(Agent):
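    """An agent that leverages the LLM's function-calling capability.

    At each step the agent asks the LLM to either request a tool call or
    answer directly, runs the requested tool, and feeds the result back,
    looping until the LLM stops calling tools or `max_steps` is reached.
    """
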
def __init__(
self,
llm: ChatModel,
tools: Union[ToolManager, List[Tool]],
memory: Memory,
system_message: Optional[SystemMessage] = None,
*,
callbacks: Optional[Union[CallbackManager, List[CallbackHandler]]] = None,
file_manager: Optional[FileManager] = None,
max_steps: Optional[int] = None,
) -> None:
super().__init__(
llm=llm,
tools=tools,
memory=memory,
system_message=system_message,
callbacks=callbacks,
file_manager=file_manager,
)
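        # Validate the caller-supplied step budget; otherwise fall back to
        # the module-level default.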
if max_steps is not None:
if max_steps <= 0:
raise ValueError("Invalid `max_steps` value")
self.max_steps = max_steps
else:
self.max_steps = _MAX_STEPS

    async def _async_run(self, prompt: str) -> AgentResponse:
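        # Per-run bookkeeping: conversation turns, tool invocations, and any
        # files produced by tools during this run.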
chat_history: List[Message] = []
actions_taken: List[AgentAction] = []
files_involved: List[AgentFile] = []
ask = HumanMessage(content=prompt)
num_steps_taken = 0
next_step_input: Message = ask
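        # Plan-act loop: each step either executes a tool or ends the run.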
while num_steps_taken < self.max_steps:
curr_step_output = await self._async_step(
next_step_input, chat_history, actions_taken, files_involved
)
if curr_step_output is None:
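                # No tool was requested: the last LLM message is the final
                # answer. Persist only the user prompt and that answer to
                # memory, keeping intermediate function-call turns out of it.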
response = self._create_finished_response(chat_history, actions_taken, files_involved)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
num_steps_taken += 1
next_step_input = curr_step_output
response = self._create_stopped_response(chat_history, actions_taken, files_involved)
return response

    async def _async_step(
        self,
        step_input: Message,
chat_history: List[Message],
actions: List[AgentAction],
files: List[AgentFile],
) -> Optional[Message]:
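        # Ask the LLM for the next move: an AgentAction if it requested a
        # tool call, or None if it answered in plain text.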
maybe_action = await self._async_plan(step_input, chat_history)
if isinstance(maybe_action, AgentAction):
action: AgentAction = maybe_action
tool_resp = await self._async_run_tool(tool_name=action.tool_name, tool_args=action.tool_args)
actions.append(action)
files.extend(tool_resp.files)
return FunctionMessage(name=action.tool_name, content=tool_resp.json)
else:
return None

    async def _async_plan(
self, input_message: Message, chat_history: List[Message]
) -> Optional[AgentAction]:
chat_history.append(input_message)
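        # Combine persisted memory with this run's history, and expose the
        # registered tools to the LLM as function schemas.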
messages = self.memory.get_messages() + chat_history
llm_resp = await self._async_run_llm(
messages=messages,
functions=self._tool_manager.get_tool_schemas(),
system=self.system_message.content if self.system_message is not None else None,
)
output_message = llm_resp.message
chat_history.append(output_message)
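        # A function call in the reply means the model wants a tool executed.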
if output_message.function_call is not None:
return AgentAction(
tool_name=output_message.function_call["name"], # type: ignore
tool_args=output_message.function_call["arguments"],
)
else:
return None

    def _create_finished_response(
self,
chat_history: List[Message],
actions: List[AgentAction],
files: List[AgentFile],
) -> AgentResponse:
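        # The last message in the history carries the model's final answer.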
last_message = chat_history[-1]
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
actions=actions,
files=files,
status="FINISHED",
)

    def _create_stopped_response(
self,
chat_history: List[Message],
actions: List[AgentAction],
files: List[AgentFile],
) -> AgentResponse:
return AgentResponse(
text="Agent run stopped early.",
chat_history=chat_history,
actions=actions,
files=files,
status="STOPPED",
)
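
# A minimal usage sketch (hypothetical setup: `ERNIEBot`, `WholeMemory`, and
# `my_tool` stand in for whatever chat model, memory, and Tool implementations
# your application provides; the public entry point on `Agent` is assumed to
# wrap `_async_run`):
#
#     agent = FunctionalAgent(
#         llm=ERNIEBot(model="ernie-3.5"),
#         tools=[my_tool],
#         memory=WholeMemory(),
#     )
#     response = await agent._async_run("What tools can you use?")
#     print(response.text)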