|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import unittest |
|
from typing import List, Optional |
|
|
|
from erniebot_agent.tools.base import RemoteToolkit |
|
from erniebot_agent.tools.schema import ( |
|
ToolParameterView, |
|
get_typing_list_type, |
|
is_optional_type, |
|
json_type, |
|
) |
|
from openapi_spec_validator.readers import read_from_filename |
|
from pydantic import Field |
|
|
|
|
|
class TestToolSchema(unittest.TestCase):
    """Tests for RemoteToolkit OpenAPI loading and the ToolParameterView schema helpers.

    NOTE: the fixture paths below are relative, so these tests assume they are
    run with the repository root as the working directory.
    """

    openapi_file = "./tests/fixtures/openapi.yaml"
    # Examples fixture that accompanies ``openapi_file`` (used by test_load_examples).
    examples_file = "./tests/fixtures/examples.yaml"

    def test_plugin_schema(self):
        """Loading the OpenAPI fixture exposes its info title and first server URL."""
        schema = RemoteToolkit.from_openapi_file(self.openapi_file)

        self.assertEqual(schema.info.title, "单词本")
        self.assertEqual(schema.servers[0].url, "http://127.0.0.1:8081")

    def test_load_and_save(self):
        """Round-tripping the fixture through RemoteToolkit reproduces the parsed spec.

        ``to_openapi_dict()`` must equal the dict parsed straight from the YAML
        fixture. Presumably this guards against the serializer emitting empty
        placeholder fields (e.g. ``items: {}``, which function_call schemas need)
        that the YAML file itself never contains.
        """
        # read_from_filename returns a tuple whose first element is the spec dict.
        spec_dict = read_from_filename(self.openapi_file)[0]
        schema = RemoteToolkit.from_openapi_file(self.openapi_file)

        saved_spec_dict = schema.to_openapi_dict()
        self.assertEqual(spec_dict, saved_spec_dict)

    def test_function_call_schemas(self):
        """Every tool in the fixture yields a function_call schema with its responses."""
        toolkit = RemoteToolkit.from_openapi_file(self.openapi_file)
        function_call_schemas = [tool.function_call_schema() for tool in toolkit.get_tools()]
        self.assertEqual(len(function_call_schemas), 4)

        first_schema = function_call_schemas[0]
        self.assertEqual(first_schema["name"], "getWordbook")
        self.assertEqual(first_schema["responses"]["required"], ["wordbook"])
        self.assertEqual(first_schema["responses"]["properties"]["wordbook"]["type"], "array")
        self.assertEqual(function_call_schemas[3]["name"], "deleteWord")

    def test_get_typing_list_type(self):
        """List[...] annotations map to their item JSON type; non-list types map to None."""
        self.assertEqual(get_typing_list_type(List[int]), "integer")
        self.assertEqual(get_typing_list_type(List[str]), "string")
        self.assertIsNone(get_typing_list_type(int))
        self.assertIsNone(get_typing_list_type(dict))
        self.assertEqual(get_typing_list_type(List[ToolParameterView]), "object")

    def test_json_type(self):
        """Python annotations map to the expected JSON schema type names."""
        self.assertEqual(json_type(List[int]), "array")
        self.assertEqual(json_type(int), "integer")
        self.assertEqual(json_type(float), "number")
        self.assertEqual(json_type(ToolParameterView), "object")

    def test_list_tool_parameter_view(self):
        """A List[<ToolParameterView>] field becomes an array whose items are the nested object schema."""

        class SearchResponseDocument(ToolParameterView):
            document: str = Field(description="和query相关的规章片段")
            filename: str = Field(description="规章名称")
            page_num: int = Field(description="规章页数")

        class SearchToolOutputView(ToolParameterView):
            documents: List[SearchResponseDocument] = Field(
                description="检索结果,内容为住房和城乡建设部规章中和query相关的规章片段"
            )

        open_api_dict = SearchToolOutputView.to_openapi_dict()

        expected_schema = {
            "type": "object",
            "required": ["documents"],
            "properties": {
                "documents": {
                    "type": "array",
                    "description": "检索结果,内容为住房和城乡建设部规章中和query相关的规章片段",
                    "items": {
                        "type": "object",
                        "required": ["document", "filename", "page_num"],
                        "properties": {
                            "document": {"type": "string", "description": "和query相关的规章片段"},
                            "filename": {"type": "string", "description": "规章名称"},
                            "page_num": {"type": "integer", "description": "规章页数"},
                        },
                    },
                }
            },
        }
        self.assertDictEqual(open_api_dict, expected_schema)

    def test_optional_tool_parameter_view(self):
        """An Optional[<ToolParameterView>] field is inlined as an object and left out of ``required``."""

        class SearchResponseDocument(ToolParameterView):
            document: str = Field(description="和query相关的规章片段")
            filename: str = Field(description="规章名称")
            page_num: int = Field(description="规章页数")

        class SearchToolOutputView(ToolParameterView):
            name: str = Field(description="测试名称")
            # NOTE(review): default_factory=list gives an Optional[<object>] field a
            # list default; ``default=None`` looks intended. Left unchanged because it
            # is not confirmed here whether to_openapi_dict() reflects defaults.
            documents: Optional[SearchResponseDocument] = Field(
                description="检索结果,内容为住房和城乡建设部规章中和query相关的规章片段", default_factory=list
            )

        openapi_dict = SearchToolOutputView.to_openapi_dict()
        expected_openapi_dict = {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "测试名称",
                },
                "documents": {
                    "type": "object",
                    "description": "检索结果,内容为住房和城乡建设部规章中和query相关的规章片段",
                    "required": ["document", "filename", "page_num"],
                    "properties": {
                        "document": {
                            "type": "string",
                            "description": "和query相关的规章片段",
                        },
                        "filename": {
                            "type": "string",
                            "description": "规章名称",
                        },
                        "page_num": {
                            "type": "integer",
                            "description": "规章页数",
                        },
                    },
                },
            },
            "required": ["name"],
        }
        self.assertDictEqual(openapi_dict, expected_openapi_dict)

    # Previously a bare ``pass`` that silently reported success; skipping makes
    # the missing coverage visible in test reports.
    @unittest.skip("TODO: enum value schema test is not implemented yet")
    def test_enum_value(self):
        """Placeholder for enum-valued field schema coverage."""

    def test_is_optional_type(self):
        """Only Optional[...] annotations are reported as optional."""
        for annotation in (List[int], int, str):
            with self.subTest(annotation=annotation):
                self.assertFalse(is_optional_type(annotation))

        for annotation in (Optional[int], Optional[str], Optional[ToolParameterView]):
            with self.subTest(annotation=annotation):
                self.assertTrue(is_optional_type(annotation))

    def test_load_examples(self):
        """Examples loaded from YAML are attached to the matching tool."""
        toolkit = RemoteToolkit.from_openapi_file(self.openapi_file)
        toolkit.examples = toolkit.load_examples_yaml(self.examples_file)
        self.assertEqual(len(toolkit.examples), 12)

        examples = toolkit.get_tool("getWordbook").examples
        self.assertEqual(len(examples), 2)
        self.assertEqual(examples[0].content, "展示单词列表")
        self.assertEqual(examples[1].function_call["name"], "getWordbook")
        self.assertEqual(examples[1].function_call["thoughts"], "这是一个展示单词本的需求")
|
|