File size: 3,137 Bytes
569cdb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import unittest

import pytest
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.message import AIMessage, FunctionMessage, HumanMessage


class TestChatModel(unittest.IsolatedAsyncioTestCase):
    """Tests for ``ERNIEBot.async_chat`` covering plain chat and function calling.

    Each test builds an ``ERNIEBot`` with ``api_type="aistudio"`` and reads the
    access token from the ``AISTUDIO_ACCESS_TOKEN`` environment variable; a
    ``KeyError`` is raised if that variable is not set.
    """

    @staticmethod
    def _make_model(model):
        """Return an ``ERNIEBot`` for ``model`` configured from the environment.

        Args:
            model: Model name forwarded to ``ERNIEBot`` (e.g. ``"ernie-bot"``).

        Raises:
            KeyError: If ``AISTUDIO_ACCESS_TOKEN`` is missing from the environment.
        """
        return ERNIEBot(
            model=model, api_type="aistudio", access_token=os.environ["AISTUDIO_ACCESS_TOKEN"]
        )

    @pytest.mark.asyncio
    async def test_chat(self):
        """A single human message yields an ``AIMessage`` whose content is set."""
        eb = self._make_model("ernie-bot-turbo")
        messages = [
            HumanMessage(content="你好!"),
        ]
        res = await eb.async_chat(messages, stream=False)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(res, AIMessage)
        self.assertIsNotNone(res.content)

    @pytest.mark.asyncio
    async def test_function_call(self):
        """Run a two-step function-call exchange.

        The first reply is expected to carry a ``function_call`` for
        ``get_current_temperature`` and no content; after the function result is
        appended to the conversation, the second reply is expected to have content.
        """
        # Allowed values shared by the request parameter and the response field.
        unit_schema = {
            "type": "string",
            "enum": [
                "摄氏度",
                "华氏度",
            ],
        }
        functions = [
            {
                "name": "get_current_temperature",
                "description": "获取指定城市的气温",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "城市名称",
                        },
                        # Copies keep the two schema entries independent objects,
                        # as they were when written out literally.
                        "unit": {**unit_schema, "enum": list(unit_schema["enum"])},
                    },
                    "required": [
                        "location",
                        "unit",
                    ],
                },
                "responses": {
                    "type": "object",
                    "properties": {
                        "temperature": {
                            "type": "integer",
                            "description": "城市气温",
                        },
                        "unit": {**unit_schema, "enum": list(unit_schema["enum"])},
                    },
                },
            }
        ]
        # use ernie-bot here since ernie-bot-turbo doesn't support function call
        eb = self._make_model("ernie-bot")
        messages = [
            HumanMessage(content="深圳市今天的气温是多少摄氏度?"),
        ]

        # Step 1: the model should answer with a function call rather than text.
        res = await eb.async_chat(messages, functions=functions)
        self.assertIsInstance(res, AIMessage)
        self.assertIsNone(res.content)
        self.assertIsNotNone(res.function_call)
        self.assertEqual(res.function_call["name"], "get_current_temperature")

        # Step 2: feed the model's own reply plus the function result back in.
        messages.append(res)
        messages.append(
            FunctionMessage(name="get_current_temperature", content='{"temperature":25,"unit":"摄氏度"}')
        )
        res = await eb.async_chat(messages, functions=functions)
        self.assertIsInstance(res, AIMessage)
        self.assertIsNotNone(res.content)