File size: 2,569 Bytes
569cdb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from unittest import mock
import pytest
from erniebot_agent.agents.base import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.messages import AIMessage
from erniebot_agent.tools.base import Tool
from tests.unit_tests.testing_utils.mocks.mock_callback_handler import (
MockCallbackHandler,
)
@pytest.mark.asyncio
async def test_callback_manager_hit():
    """Fire every CallbackManager hook once and check each handler saw each hook once.

    Two handlers are registered so the test also covers fan-out: every event
    must reach *all* registered handlers, not just the first one.
    """
    # Counter attributes on MockCallbackHandler, one per hook fired below.
    expected_once = (
        "run_starts",
        "llm_starts",
        "llm_ends",
        "llm_errors",
        "tool_starts",
        "tool_ends",
        "tool_errors",
        "run_ends",
    )

    def _assert_num_calls(handler):
        # Each hook is triggered exactly once below, so every counter must be 1.
        for counter in expected_once:
            assert getattr(handler, counter) == 1

    fake_agent = mock.Mock(spec=Agent)
    fake_llm = mock.Mock(spec=ChatModel)
    fake_tool = mock.Mock(spec=Tool)

    first, second = MockCallbackHandler(), MockCallbackHandler()
    manager = CallbackManager(handlers=[first, second])

    await manager.on_run_start(fake_agent, "")

    await manager.on_llm_start(fake_agent, fake_llm, [])
    llm_reply = AIMessage(
        content="",
        function_call=None,
        token_usage={"prompt_tokens": 0, "completion_tokens": 0},
    )
    await manager.on_llm_end(fake_agent, fake_llm, llm_reply)
    await manager.on_llm_error(fake_agent, fake_llm, Exception())

    await manager.on_tool_start(fake_agent, fake_tool, "{}")
    await manager.on_tool_end(fake_agent, fake_tool, "{}")
    await manager.on_tool_error(fake_agent, fake_tool, Exception())

    final_response = AgentResponse(
        text="", chat_history=[], actions=[], files=[], status="FINISHED"
    )
    await manager.on_run_end(fake_agent, final_response)

    for handler in (first, second):
        _assert_num_calls(handler)
@pytest.mark.asyncio
async def test_callback_manager_add_remove_handlers():
    """Check adding, removing and clearing handlers on a CallbackManager.

    Registering a handler that is already present must raise ``RuntimeError``
    and must leave the set of registered handlers unchanged.
    """
    handler1 = MockCallbackHandler()
    handler2 = MockCallbackHandler()
    callback_manager = CallbackManager(handlers=[handler1])
    assert len(callback_manager.handlers) == 1

    # A duplicate registration is rejected...
    with pytest.raises(RuntimeError):
        callback_manager.add_handler(handler1)
    # ...and the rejected call must not have registered a second copy.
    assert len(callback_manager.handlers) == 1

    callback_manager.remove_handler(handler1)
    assert len(callback_manager.handlers) == 0
    assert handler1 not in callback_manager.handlers

    # A removed handler can be registered again.
    callback_manager.add_handler(handler1)
    assert len(callback_manager.handlers) == 1
    callback_manager.add_handler(handler2)
    assert len(callback_manager.handlers) == 2
    # Check identity as well as count: both specific handlers are registered.
    assert handler1 in callback_manager.handlers
    assert handler2 in callback_manager.handlers

    callback_manager.remove_all_handlers()
    assert len(callback_manager.handlers) == 0
|