# markqiu's picture
# 百度文心一言的例子 (Baidu ERNIE Bot example)
# 569cdb0
from unittest import mock
import pytest
from erniebot_agent.agents.base import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.messages import AIMessage
from erniebot_agent.tools.base import Tool
from tests.unit_tests.testing_utils.mocks.mock_callback_handler import (
MockCallbackHandler,
)
@pytest.mark.asyncio
async def test_callback_manager_hit():
    """Each event dispatched through a CallbackManager reaches every handler once.

    Two independent MockCallbackHandler instances are registered; one call is
    made for every lifecycle hook (run, LLM and tool start/end/error), and
    afterwards each handler's per-hook counter must read exactly 1.
    """
    # Spec'd mocks stand in for the real agent, chat model and tool; the
    # manager only forwards them to the handlers.
    fake_agent = mock.Mock(spec=Agent)
    fake_llm = mock.Mock(spec=ChatModel)
    fake_tool = mock.Mock(spec=Tool)

    first_handler = MockCallbackHandler()
    second_handler = MockCallbackHandler()
    manager = CallbackManager(handlers=[first_handler, second_handler])

    # Run start, then the LLM hooks.
    await manager.on_run_start(fake_agent, "")
    await manager.on_llm_start(fake_agent, fake_llm, [])
    llm_reply = AIMessage(
        content="",
        function_call=None,
        token_usage={"prompt_tokens": 0, "completion_tokens": 0},
    )
    await manager.on_llm_end(fake_agent, fake_llm, llm_reply)
    await manager.on_llm_error(fake_agent, fake_llm, Exception())

    # Tool hooks.
    await manager.on_tool_start(fake_agent, fake_tool, "{}")
    await manager.on_tool_end(fake_agent, fake_tool, "{}")
    await manager.on_tool_error(fake_agent, fake_tool, Exception())

    # Run end.
    final_response = AgentResponse(
        text="", chat_history=[], actions=[], files=[], status="FINISHED"
    )
    await manager.on_run_end(fake_agent, final_response)

    # Counter attributes on MockCallbackHandler, in hook order.
    counters = (
        "run_starts",
        "llm_starts",
        "llm_ends",
        "llm_errors",
        "tool_starts",
        "tool_ends",
        "tool_errors",
        "run_ends",
    )
    for registered in (first_handler, second_handler):
        for counter in counters:
            assert getattr(registered, counter) == 1
def test_callback_manager_add_remove_handlers():
    """Exercise registering and unregistering handlers on a CallbackManager.

    Covers four behaviours:

    - registering an already-registered handler raises ``RuntimeError`` and
      leaves the registry unchanged;
    - a handler can be removed and then registered again;
    - several distinct handlers can be registered side by side;
    - ``remove_all_handlers`` empties the registry.

    This is a plain synchronous test: none of the CallbackManager methods used
    here are awaited, so no event loop (and no ``pytest.mark.asyncio``) is
    required.
    """
    handler1 = MockCallbackHandler()
    handler2 = MockCallbackHandler()
    callback_manager = CallbackManager(handlers=[handler1])
    assert len(callback_manager.handlers) == 1

    # A duplicate registration must be rejected...
    with pytest.raises(RuntimeError):
        callback_manager.add_handler(handler1)
    # ...and the rejected call must not have altered the registry.
    assert len(callback_manager.handlers) == 1

    # Remove, then re-register, the same handler.
    callback_manager.remove_handler(handler1)
    assert len(callback_manager.handlers) == 0
    callback_manager.add_handler(handler1)
    assert len(callback_manager.handlers) == 1

    # A second, distinct handler coexists with the first.
    callback_manager.add_handler(handler2)
    assert len(callback_manager.handlers) == 2

    # Bulk removal clears everything.
    callback_manager.remove_all_handlers()
    assert len(callback_manager.handlers) == 0