|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import inspect |
|
from typing import TYPE_CHECKING, Any, List, Union, final |
|
|
|
from erniebot_agent.agents.callback.event import EventType |
|
from erniebot_agent.agents.callback.handlers.base import CallbackHandler |
|
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse |
|
from erniebot_agent.chat_models.base import ChatModel |
|
from erniebot_agent.messages import Message |
|
from erniebot_agent.tools.base import Tool |
|
|
|
if TYPE_CHECKING: |
|
from erniebot_agent.agents.base import Agent |
|
|
|
|
|
@final
class CallbackManager(object):
    """Dispatch agent lifecycle events to a list of callback handlers.

    For an event ``event_type``, every registered handler's
    ``on_<event_type.value>`` coroutine method is awaited, one after another,
    in registration order. The ``on_*`` methods of this class are thin
    convenience wrappers that forward their arguments as keyword arguments to
    :meth:`handle_event`.
    """

    def __init__(self, handlers: List[CallbackHandler]):
        """Create a manager that dispatches to ``handlers``.

        Args:
            handlers: Initial handlers. The list is copied, so later calls to
                :meth:`add_handler` / :meth:`remove_handler` never mutate the
                caller's list.

        Raises:
            RuntimeError: If ``handlers`` contains the same handler twice.
        """
        super().__init__()
        self._handlers: List[CallbackHandler] = []
        # Route through set_handlers so construction applies the same duplicate
        # check as add_handler/set_handlers and does not alias the input list.
        self.set_handlers(handlers)

    @property
    def handlers(self) -> List[CallbackHandler]:
        """The registered handlers (the manager's own internal list)."""
        return self._handlers

    def add_handler(self, handler: CallbackHandler):
        """Append ``handler`` to the dispatch list.

        Raises:
            RuntimeError: If ``handler`` is already registered.
        """
        if handler in self._handlers:
            raise RuntimeError(f"The callback handler {handler} is already registered.")
        self._handlers.append(handler)

    def remove_handler(self, handler: CallbackHandler):
        """Unregister ``handler``.

        Raises:
            RuntimeError: If ``handler`` is not registered.
        """
        try:
            self._handlers.remove(handler)
        except ValueError as e:
            raise RuntimeError(f"The callback handler {handler} is not registered.") from e

    def set_handlers(self, handlers: List[CallbackHandler]):
        """Replace all registered handlers with ``handlers``.

        The new list is fully validated before it is installed. If validation
        fails, the previously registered handlers are left unchanged.

        Raises:
            RuntimeError: If ``handlers`` contains the same handler twice.
        """
        new_handlers: List[CallbackHandler] = []
        for handler in handlers:
            if handler in new_handlers:
                raise RuntimeError(f"The callback handler {handler} is already registered.")
            new_handlers.append(handler)
        self._handlers = new_handlers

    def remove_all_handlers(self):
        """Unregister every handler."""
        self._handlers = []

    async def handle_event(self, event_type: EventType, *args: Any, **kwargs: Any) -> None:
        """Await ``on_<event_type.value>(*args, **kwargs)`` on each handler in order.

        Handlers are awaited sequentially. An exception raised by one handler
        propagates immediately and the remaining handlers are not called.

        Raises:
            TypeError: If a handler's callback attribute is missing or is not a
                coroutine function.
        """
        callback_name = "on_" + event_type.value
        # Iterate over a snapshot: a callback (or another task running while we
        # are suspended at ``await``) may add/remove handlers, and removing an
        # element from the list being iterated would skip the next handler.
        for handler in list(self._handlers):
            callback = getattr(handler, callback_name, None)
            if not inspect.iscoroutinefunction(callback):
                raise TypeError(
                    "Callback must be a coroutine function, but "
                    f"{type(handler).__name__}.{callback_name} is {callback!r}."
                )
            await callback(*args, **kwargs)

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        """Dispatch the ``RUN_START`` event."""
        await self.handle_event(EventType.RUN_START, agent=agent, prompt=prompt)

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        """Dispatch the ``LLM_START`` event."""
        await self.handle_event(EventType.LLM_START, agent=agent, llm=llm, messages=messages)

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        """Dispatch the ``LLM_END`` event."""
        await self.handle_event(EventType.LLM_END, agent=agent, llm=llm, response=response)

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Dispatch the ``LLM_ERROR`` event."""
        await self.handle_event(EventType.LLM_ERROR, agent=agent, llm=llm, error=error)

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        """Dispatch the ``TOOL_START`` event."""
        await self.handle_event(EventType.TOOL_START, agent=agent, tool=tool, input_args=input_args)

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        """Dispatch the ``TOOL_END`` event."""
        await self.handle_event(EventType.TOOL_END, agent=agent, tool=tool, response=response)

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Dispatch the ``TOOL_ERROR`` event."""
        await self.handle_event(EventType.TOOL_ERROR, agent=agent, tool=tool, error=error)

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        """Dispatch the ``RUN_END`` event."""
        await self.handle_event(EventType.RUN_END, agent=agent, response=response)
|
|