File size: 3,984 Bytes
569cdb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, List, Union, final

from erniebot_agent.agents.callback.event import EventType
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.messages import Message
from erniebot_agent.tools.base import Tool

if TYPE_CHECKING:
    from erniebot_agent.agents.base import Agent


@final
class CallbackManager(object):
    """Dispatch agent lifecycle events to a list of registered callback handlers.

    Each ``on_*`` method of this class forwards its arguments as keyword
    arguments to :meth:`handle_event`. That method looks up the handler
    method named ``"on_" + event_type.value`` on every registered handler and
    awaits it, in list order.
    """

    def __init__(self, handlers: List[CallbackHandler]):
        """Create a manager with an initial set of handlers.

        Args:
            handlers: Handlers to register. The sequence is copied, so later
                ``add_handler``/``remove_handler`` calls never mutate the
                caller's list.

        Raises:
            RuntimeError: If the same handler appears more than once.
        """
        super().__init__()
        self._handlers: List[CallbackHandler] = []
        # Route through set_handlers so construction copies the input and
        # enforces the same no-duplicates rule as add_handler/set_handlers.
        self.set_handlers(handlers)

    @property
    def handlers(self) -> List[CallbackHandler]:
        """The registered handlers, in dispatch order.

        This is the manager's internal list, not a copy.
        """
        return self._handlers

    def add_handler(self, handler: CallbackHandler):
        """Append ``handler`` to the registered handlers.

        Raises:
            RuntimeError: If ``handler`` is already registered.
        """
        if handler in self._handlers:
            raise RuntimeError(f"The callback handler {handler} is already registered.")
        self._handlers.append(handler)

    def remove_handler(self, handler: CallbackHandler):
        """Unregister ``handler``.

        Raises:
            RuntimeError: If ``handler`` is not registered.
        """
        try:
            self._handlers.remove(handler)
        except ValueError as e:
            raise RuntimeError(f"The callback handler {handler} is not registered.") from e

    def set_handlers(self, handlers: List[CallbackHandler]):
        """Replace all registered handlers with ``handlers``.

        The new list is fully validated before it is installed, so a
        duplicate leaves the previously registered handlers untouched.

        Raises:
            RuntimeError: If the same handler appears more than once in
                ``handlers``.
        """
        new_handlers: List[CallbackHandler] = []
        for handler in handlers:
            if handler in new_handlers:
                raise RuntimeError(f"The callback handler {handler} is already registered.")
            new_handlers.append(handler)
        self._handlers = new_handlers

    def remove_all_handlers(self):
        """Unregister every handler."""
        self._handlers = []

    async def handle_event(self, event_type: EventType, *args: Any, **kwargs: Any) -> None:
        """Await the ``on_<event_type.value>`` callback of every registered handler.

        Args:
            event_type: Event being dispatched; ``"on_" + event_type.value``
                is the name of the handler method that gets called.
            *args: Positional arguments forwarded to each callback.
            **kwargs: Keyword arguments forwarded to each callback.

        Raises:
            TypeError: If a handler has no such callback, or the callback is
                not a coroutine function.
        """
        callback_name = "on_" + event_type.value
        # Iterate over a snapshot so a callback that adds or removes handlers
        # cannot cause another handler to be skipped during this dispatch.
        for handler in tuple(self._handlers):
            callback = getattr(handler, callback_name, None)
            if not inspect.iscoroutinefunction(callback):
                raise TypeError(
                    f"Callback must be a coroutine function, but `{callback_name}` "
                    f"of handler {handler!r} is {callback!r}."
                )
            await callback(*args, **kwargs)

    async def on_run_start(self, agent: Agent, prompt: str) -> None:
        """Dispatch ``EventType.RUN_START`` with ``agent`` and ``prompt``."""
        await self.handle_event(EventType.RUN_START, agent=agent, prompt=prompt)

    async def on_llm_start(self, agent: Agent, llm: ChatModel, messages: List[Message]) -> None:
        """Dispatch ``EventType.LLM_START`` with ``agent``, ``llm`` and ``messages``."""
        await self.handle_event(EventType.LLM_START, agent=agent, llm=llm, messages=messages)

    async def on_llm_end(self, agent: Agent, llm: ChatModel, response: LLMResponse) -> None:
        """Dispatch ``EventType.LLM_END`` with ``agent``, ``llm`` and ``response``."""
        await self.handle_event(EventType.LLM_END, agent=agent, llm=llm, response=response)

    async def on_llm_error(
        self, agent: Agent, llm: ChatModel, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Dispatch ``EventType.LLM_ERROR`` with ``agent``, ``llm`` and ``error``."""
        await self.handle_event(EventType.LLM_ERROR, agent=agent, llm=llm, error=error)

    async def on_tool_start(self, agent: Agent, tool: Tool, input_args: str) -> None:
        """Dispatch ``EventType.TOOL_START`` with ``agent``, ``tool`` and ``input_args``."""
        await self.handle_event(EventType.TOOL_START, agent=agent, tool=tool, input_args=input_args)

    async def on_tool_end(self, agent: Agent, tool: Tool, response: ToolResponse) -> None:
        """Dispatch ``EventType.TOOL_END`` with ``agent``, ``tool`` and ``response``."""
        await self.handle_event(EventType.TOOL_END, agent=agent, tool=tool, response=response)

    async def on_tool_error(
        self, agent: Agent, tool: Tool, error: Union[Exception, KeyboardInterrupt]
    ) -> None:
        """Dispatch ``EventType.TOOL_ERROR`` with ``agent``, ``tool`` and ``error``."""
        await self.handle_event(EventType.TOOL_ERROR, agent=agent, tool=tool, error=error)

    async def on_run_end(self, agent: Agent, response: AgentResponse) -> None:
        """Dispatch ``EventType.RUN_END`` with ``agent`` and ``response``."""
        await self.handle_event(EventType.RUN_END, agent=agent, response=response)