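# NOTE: the following is residue from the web page this file was copied from, not module content:
#   uploader caption "markqiu's picture"; commit message "百度文心一言的例子"
#   (an example for Baidu ERNIE Bot); commit 569cdb0.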
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from typing import List, Optional
from erniebot_agent.tools.base import RemoteToolkit
from erniebot_agent.tools.schema import (
ToolParameterView,
get_typing_list_type,
is_optional_type,
json_type,
)
from openapi_spec_validator.readers import read_from_filename
from pydantic import Field
class TestToolSchema(unittest.TestCase):
    """Tests for OpenAPI loading and the schema helpers in ``erniebot_agent.tools``."""

    # Relative path of the OpenAPI fixture; it is resolved against the current
    # working directory of the test run.
    openapi_file = "./tests/fixtures/openapi.yaml"

    def test_plugin_schema(self):
        """The toolkit exposes the title and first server URL from the fixture."""
        schema = RemoteToolkit.from_openapi_file(self.openapi_file)

        self.assertEqual(schema.info.title, "单词本")
        self.assertEqual(schema.servers[0].url, "http://127.0.0.1:8081")

    def test_load_and_save(self):
        """Loading the fixture and dumping it again reproduces the raw spec.

        function_call requires empty fields (e.g. ``items: {}``), but the yaml
        file doesn't contain any empty field, so the round trip must not add any.
        """
        # Only the first element of the reader's result is compared with the dump.
        spec_dict = read_from_filename(self.openapi_file)
        schema = RemoteToolkit.from_openapi_file(self.openapi_file)
        saved_spec_dict = schema.to_openapi_dict()

        self.assertEqual(spec_dict[0], saved_spec_dict)

    def test_function_call_schemas(self):
        """Each tool of the fixture toolkit yields a function-call schema."""
        toolkit = RemoteToolkit.from_openapi_file(self.openapi_file)
        function_call_schemas = [tool.function_call_schema() for tool in toolkit.get_tools()]

        self.assertEqual(len(function_call_schemas), 4)

        first_schema = function_call_schemas[0]
        self.assertEqual(first_schema["name"], "getWordbook")
        self.assertEqual(first_schema["responses"]["required"], ["wordbook"])
        self.assertEqual(first_schema["responses"]["properties"]["wordbook"]["type"], "array")

        self.assertEqual(function_call_schemas[3]["name"], "deleteWord")

    def test_get_typing_list_type(self):
        """``get_typing_list_type`` maps ``List[T]`` to T's JSON type, else ``None``."""
        self.assertEqual(get_typing_list_type(List[int]), "integer")
        self.assertEqual(get_typing_list_type(List[str]), "string")
        self.assertEqual(get_typing_list_type(List[ToolParameterView]), "object")

        # Anything that is not a ``typing.List`` has no element type to report.
        self.assertIsNone(get_typing_list_type(int))
        self.assertIsNone(get_typing_list_type(dict))

    def test_json_type(self):
        """``json_type`` maps Python types to their JSON schema type names."""
        self.assertEqual(json_type(List[int]), "array")
        self.assertEqual(json_type(int), "integer")
        self.assertEqual(json_type(float), "number")
        self.assertEqual(json_type(ToolParameterView), "object")

    def test_list_tool_parameter_view(self):
        """A ``List`` of nested views is dumped as an array of objects."""

        class SearchResponseDocument(ToolParameterView):
            document: str = Field(description="和query相关的规章片段")
            filename: str = Field(description="规章名称")
            page_num: int = Field(description="规章页数")

        class SearchToolOutputView(ToolParameterView):
            documents: List[SearchResponseDocument] = Field(
                description="检索结果,内容为住房和城乡建设部规章中和query相关的规章片段"
            )

        open_api_dict = SearchToolOutputView.to_openapi_dict()
        expected_schema = {
            "type": "object",
            "required": ["documents"],
            "properties": {
                "documents": {
                    "type": "array",
                    "description": "检索结果,内容为住房和城乡建设部规章中和query相关的规章片段",
                    "items": {
                        "type": "object",
                        "required": ["document", "filename", "page_num"],
                        "properties": {
                            "document": {"type": "string", "description": "和query相关的规章片段"},
                            "filename": {"type": "string", "description": "规章名称"},
                            "page_num": {"type": "integer", "description": "规章页数"},
                        },
                    },
                }
            },
        }
        self.assertDictEqual(open_api_dict, expected_schema)

    def test_optional_tool_parameter_view(self):
        """An ``Optional`` nested view is dumped as an object and is not required."""

        class SearchResponseDocument(ToolParameterView):
            document: str = Field(description="和query相关的规章片段")
            filename: str = Field(description="规章名称")
            page_num: int = Field(description="规章页数")

        class SearchToolOutputView(ToolParameterView):
            name: str = Field(description="测试名称")
            documents: Optional[SearchResponseDocument] = Field(
                description="检索结果,内容为住房和城乡建设部规章中和query相关的规章片段", default_factory=list
            )

        openapi_dict = SearchToolOutputView.to_openapi_dict()
        expected_openapi_dict = {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "测试名称",
                },
                "documents": {
                    "type": "object",
                    "description": "检索结果,内容为住房和城乡建设部规章中和query相关的规章片段",
                    "required": ["document", "filename", "page_num"],
                    "properties": {
                        "document": {
                            "type": "string",
                            "description": "和query相关的规章片段",
                        },
                        "filename": {
                            "type": "string",
                            "description": "规章名称",
                        },
                        "page_num": {
                            "type": "integer",
                            "description": "规章页数",
                        },
                    },
                },
            },
            # Only the non-optional field is listed as required.
            "required": ["name"],
        }
        self.assertDictEqual(openapi_dict, expected_openapi_dict)

    def test_enum_value(self):
        """Placeholder for enum parameter support; reported as skipped until written."""
        # TODO(wj-Mcat): to support enum[int/str/float]
        # Skip explicitly so the missing coverage shows up in the test report
        # instead of counting as a passing test.
        self.skipTest("TODO: enum[int/str/float] support is not tested yet")

    def test_is_optional_type(self):
        """``is_optional_type`` is true only for ``Optional[...]`` annotations."""
        for plain_type in (List[int], int, str):
            self.assertFalse(is_optional_type(plain_type))

        for optional_type in (Optional[int], Optional[str], Optional[ToolParameterView]):
            self.assertTrue(is_optional_type(optional_type))

    def test_load_examples(self):
        """Examples loaded from yaml are attached to the toolkit and its tools."""
        # Use the shared fixture path so all tests load the same spec.
        toolkit = RemoteToolkit.from_openapi_file(self.openapi_file)
        toolkit.examples = toolkit.load_examples_yaml("./tests/fixtures/examples.yaml")

        self.assertEqual(len(toolkit.examples), 12)

        # getWordbook examples
        examples = toolkit.get_tool("getWordbook").examples
        self.assertEqual(len(examples), 2)
        self.assertEqual(examples[0].content, "展示单词列表")
        self.assertEqual(examples[1].function_call["name"], "getWordbook")
        self.assertEqual(examples[1].function_call["thoughts"], "这是一个展示单词本的需求")