# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import inspect from dataclasses import dataclass from typing import Any, Dict, List, Optional, Type, get_args from erniebot_agent.utils.logging import logger from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo INVALID_FIELD_NAME = "__invalid_field_name__" def is_optional_type(type: Type): args = get_args(type) if len(args) == 0: return False return len([arg for arg in args if arg is None.__class__]) > 0 def get_typing_list_type(type): """get typing.List[T] element type Args: type (typing.List): Generics type """ # 1. checking list type if getattr(type, "_name", None) != "List": return None arg_type = get_args(type)[0] return json_type(arg_type) def json_type(type: Optional[Type[object]] = None): if type is None: return "object" mapping = { int: "integer", str: "string", list: "array", List: "array", float: "number", ToolParameterView: "object", } if inspect.isclass(type) and issubclass(type, ToolParameterView): return "object" if getattr(type, "_name", None) == "List": return "array" if type not in mapping: args = [arg for arg in get_args(type) if arg is not None.__class__] if len(args) > 1 or len(args) == 0: raise ValueError( "only support simple type: FieldType=int/str/float/ToolParameterView, " "so the target type should be one of: FieldType, List[FieldType], " f"Optional[FieldType], but receive {type}" ) type = args[0] if type in mapping: return mapping[type] if inspect.isclass(type) and issubclass(type, ToolParameterView): return "object" return str(type) def python_type_from_json_type(json_type_dict: dict) -> Type[object]: simple_types = {"integer": int, "string": str, "number": float, "object": ToolParameterView} if json_type_dict["type"] in simple_types: return simple_types[json_type_dict["type"]] assert ( json_type_dict["type"] == "array" ), f"only support simple_types<{','.join(simple_types)}> and array type" assert "type" in json_type_dict["items"], " field must be defined when 'type'=array" json_type_value = json_type_dict["items"]["type"] if json_type_value == "string": return List[str] if json_type_value == "integer": return List[int] if json_type_value == "number": return List[float] if json_type_value == "object": return List[ToolParameterView] raise ValueError(f"unsupported data type: {json_type_value}") def scrub_dict(d: dict, remove_empty_dict: bool = False) -> Optional[dict]: """remove empty Value node, function_call_schema: require Args: d (dict): the instance of dictionary remove_empty_dict (bool): whether remove empty dict Returns: dict: the dictionary data after slimming down """ if type(d) is dict: result = {} for k, v in d.items(): v = scrub_dict(v, remove_empty_dict) if v is not None: result[k] = v if len(result) == 0: if not remove_empty_dict: return {} return None return result elif isinstance(d, list): return [scrub_dict(item, remove_empty_dict) for item in d] # type: ignore else: return d class OpenAPIProperty(BaseModel): type: str description: Optional[str] = None required: Optional[List[str]] = None items: dict = Field(default_factory=dict) properties: dict = Field(default_factory=dict) def get_field_openapi_property(field_info: FieldInfo) -> OpenAPIProperty: """convert pydantic FieldInfo instance to OpenAPIProperty value Args: field_info (FieldInfo): the field instance Returns: OpenAPIProperty: the converted OpenAPI Property """ typing_list_type = get_typing_list_type(field_info.annotation) if typing_list_type is not None: field_type = "array" elif is_optional_type(field_info.annotation): field_type = json_type(get_args(field_info.annotation)[0]) else: field_type = json_type(field_info.annotation) property = { "type": field_type, "description": field_info.description, } if property["type"] == "array": if typing_list_type == "object": list_type: Type[ToolParameterView] = get_args(field_info.annotation)[0] property["items"] = list_type.to_openapi_dict() else: property["items"] = {"type": typing_list_type} elif property["type"] == "object": if is_optional_type(field_info.annotation): field_type_class: Type[ToolParameterView] = get_args(field_info.annotation)[0] else: field_type_class = field_info.annotation openapi_dict = field_type_class.to_openapi_dict() property.update(openapi_dict) property["description"] = property.get("description", "") return OpenAPIProperty(**property) class ToolParameterView(BaseModel): @classmethod def from_openapi_dict(cls, name, schema: dict) -> Type[ToolParameterView]: """parse openapi component schemas to ParameterView Args: response_or_returns (dict): the content of status code Returns: _type_: _description_ """ # TODO(wj-Mcat): to load Optional field fields = {} for field_name, field_dict in schema.get("properties", {}).items(): field_type = python_type_from_json_type(field_dict) if field_type is List[ToolParameterView]: SubParameterView: Type[ToolParameterView] = ToolParameterView.from_openapi_dict( field_name, field_dict["items"] ) field_type = List[SubParameterView] # type: ignore # TODO(wj-Mcat): remove supporting for `summary` field if "summary" in field_dict: description = field_dict["summary"] logger.info("`summary` field will be deprecated, please use `description`") if "description" in field_dict: logger.info("`description` field will be used instead of `summary`") description = field_dict["description"] else: description = field_dict.get("description", None) description = description or "" field = FieldInfo(annotation=field_type, description=description) # TODO(wj-Mcat): to handle list field required & not-required # if get_typing_list_type(field_type) is not None: # field.default_factory = list fields[field_name] = (field_type, field) return create_model("OpenAPIParameterView", __base__=ToolParameterView, **fields) @classmethod def to_openapi_dict(cls) -> dict: """convert ParametersView to openapi spec dict Returns: dict: schema of openapi """ required_names, properties = [], {} for field_name, field_info in cls.model_fields.items(): if field_info.is_required() and not is_optional_type(field_info.annotation): required_names.append(field_name) properties[field_name] = dict(get_field_openapi_property(field_info)) result = { "type": "object", "properties": properties, } if len(required_names) > 0: result["required"] = required_names result = scrub_dict(result, remove_empty_dict=True) # type: ignore return result or {} @classmethod def function_call_schema(cls) -> dict: """get function_call schame Returns: dict: the schema of function_call """ return cls.to_openapi_dict() @classmethod def from_dict(cls, field_map: Dict[str, Any]): """ Class method to create a Pydantic model dynamically based on a dictionary. Args: field_map (Dict[str, Any]): A dictionary mapping field names to their corresponding type and description. Returns: PydanticModel: A dynamically created Pydantic model with fields specified by the input dictionary. Note: This method is used to create a Pydantic model dynamically based on the provided dictionary, where each field's type and description are specified in the input. """ fields = {} for field_name, field_dict in field_map.items(): field_type = field_dict["type"] description = field_dict["description"] field = FieldInfo(annotation=field_type, description=description) fields[field_name] = (field_type, field) return create_model(cls.__name__, __base__=ToolParameterView, **fields) @dataclass class RemoteToolView: uri: str method: str name: str description: str parameters: Optional[Type[ToolParameterView]] = None parameters_description: Optional[str] = None returns: Optional[Type[ToolParameterView]] = None returns_description: Optional[str] = None returns_ref_uri: Optional[str] = None parameters_ref_uri: Optional[str] = None def to_openapi_dict(self): result = { "operationId": self.name, "description": self.description, } if self.returns is not None: response = { "200": { "description": self.returns_description, "content": { "application/json": { "schema": {"$ref": "#/components/schemas/" + (self.returns_ref_uri or "")} } }, } } result["responses"] = response if self.parameters is not None: parameters = { "required": True, "content": { "application/json": { "schema": {"$ref": "#/components/schemas/" + (self.parameters_ref_uri or "")} } }, } result["requestBody"] = parameters return {self.method: result} @staticmethod def from_openapi_dict( uri: str, method: str, path_info: dict, parameters_views: dict[str, Type[ToolParameterView]] ) -> RemoteToolView: """construct RemoteToolView from openapi spec-dict info Args: uri (str): the url path of remote tool method (str): http method: one of [get, post, put, delete] path_info (dict): the spec info of remote tool parameters_views (dict[str, ParametersView]): the dict of parameters views which are the schema of input/output of tool Returns: RemoteToolView: the instance of remote tool view """ parameters_ref_uri, returns_ref_uri = None, None parameters, parameters_description = None, None if "requestBody" in path_info: request_ref = path_info["requestBody"]["content"]["application/json"]["schema"]["$ref"] parameters_ref_uri = request_ref.split("/")[-1] assert parameters_ref_uri in parameters_views parameters = parameters_views[parameters_ref_uri] parameters_description = path_info["requestBody"].get("description", None) returns, returns_description = None, None if "responses" in path_info: response_ref = list(path_info["responses"].values())[0]["content"]["application/json"]["schema"][ "$ref" ] returns_ref_uri = response_ref.split("/")[-1] assert returns_ref_uri in parameters_views returns = parameters_views[returns_ref_uri] returns_description = list(path_info["responses"].values())[0].get("description", None) return RemoteToolView( name=path_info["operationId"], parameters=parameters, parameters_description=parameters_description, returns=returns, returns_description=returns_description, description=path_info.get("description", path_info.get("summary", None)), method=method, uri=uri, # save ref id info returns_ref_uri=returns_ref_uri, parameters_ref_uri=parameters_ref_uri, ) def function_call_schema(self): inputs = { "name": self.name, "description": self.description, # TODO(wj-Mcat): read examples from openapi.yaml # "examples": [example.to_dict() for example in self.examples], } if self.parameters is not None: inputs["parameters"] = self.parameters.function_call_schema() # type: ignore else: inputs["parameters"] = {"type": "object", "properties": {}} if self.returns is not None: inputs["responses"] = self.returns.function_call_schema() # type: ignore return scrub_dict(inputs) or {} @dataclass class Endpoint: url: str @dataclass class EndpointInfo: title: str description: str version: str