File size: 2,569 Bytes
569cdb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from unittest import mock

import pytest
from erniebot_agent.agents.base import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.messages import AIMessage
from erniebot_agent.tools.base import Tool

from tests.unit_tests.testing_utils.mocks.mock_callback_handler import (
    MockCallbackHandler,
)


@pytest.mark.asyncio
async def test_callback_manager_hit():
    """Fire each callback-manager event once and check both handlers' counters.

    Two ``MockCallbackHandler`` instances are registered on one
    ``CallbackManager``. Every ``on_*`` event is awaited exactly once, after
    which each handler's per-event counter is expected to equal 1.
    """
    counter_names = (
        "run_starts",
        "llm_starts",
        "llm_ends",
        "llm_errors",
        "tool_starts",
        "tool_ends",
        "tool_errors",
        "run_ends",
    )

    def _assert_num_calls(handler):
        # Every counter on the handler must read exactly one.
        for counter in counter_names:
            assert getattr(handler, counter) == 1

    # Spec'd mocks stand in for the real agent, chat model and tool.
    agent = mock.Mock(spec=Agent)
    llm = mock.Mock(spec=ChatModel)
    tool = mock.Mock(spec=Tool)

    first_handler = MockCallbackHandler()
    second_handler = MockCallbackHandler()
    callback_manager = CallbackManager(handlers=[first_handler, second_handler])

    await callback_manager.on_run_start(agent, "")

    # LLM events: start, end, error.
    await callback_manager.on_llm_start(agent, llm, [])
    llm_reply = AIMessage(
        content="",
        function_call=None,
        token_usage={"prompt_tokens": 0, "completion_tokens": 0},
    )
    await callback_manager.on_llm_end(agent, llm, llm_reply)
    await callback_manager.on_llm_error(agent, llm, Exception())

    # Tool events: start, end, error.
    await callback_manager.on_tool_start(agent, tool, "{}")
    await callback_manager.on_tool_end(agent, tool, "{}")
    await callback_manager.on_tool_error(agent, tool, Exception())

    final_response = AgentResponse(
        text="", chat_history=[], actions=[], files=[], status="FINISHED"
    )
    await callback_manager.on_run_end(agent, final_response)

    for registered in (first_handler, second_handler):
        _assert_num_calls(registered)


@pytest.mark.asyncio
async def test_callback_manager_add_remove_handlers():
    """Check handler bookkeeping on ``CallbackManager``.

    Walks through adding, removing and clearing handlers, asserting the
    length of ``callback_manager.handlers`` after each step. Adding a handler
    that is already registered is expected to raise ``RuntimeError``.
    """
    first_handler = MockCallbackHandler()
    second_handler = MockCallbackHandler()
    callback_manager = CallbackManager(handlers=[first_handler])

    def _assert_handler_count(expected):
        # Read the property afresh on every check.
        assert len(callback_manager.handlers) == expected

    _assert_handler_count(1)

    # Registering the same handler a second time must be rejected.
    with pytest.raises(RuntimeError):
        callback_manager.add_handler(first_handler)

    callback_manager.remove_handler(first_handler)
    _assert_handler_count(0)

    # Once removed, the same handler can be registered again.
    callback_manager.add_handler(first_handler)
    _assert_handler_count(1)

    callback_manager.add_handler(second_handler)
    _assert_handler_count(2)

    callback_manager.remove_all_handlers()
    _assert_handler_count(0)