韩宇
init
1b7e88c
import json
from abc import ABC
from distutils.util import strtobool
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import yaml
from omagent_core.base import BotBase
from omagent_core.models.od.schemas import Target
from omagent_core.services.handlers.sql_data_handler import SQLDataHandler
from omagent_core.utils.error import VQLError
from omagent_core.utils.logger import logging
from omagent_core.utils.plot import Annotator
from PIL import Image
from pydantic import BaseModel, model_validator
# Maps Python type names to the JSON-schema type names used in tool schemas.
# Declared argument types may use either spelling.
_ARG_TYPE_ALIASES = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
}


class ArgSchema(BaseModel):
    """Input schema of a tool.

    Every extra field set on the model is one tool argument whose value is an
    ``ArgInfo``. Only a flat, one-level definition is supported; please avoid
    nested or complex structures.
    """

    class Config:
        """Configuration for this pydantic object."""

        # Tool arguments are stored as extra (undeclared) fields.
        extra = "allow"
        arbitrary_types_allowed = True

    class ArgInfo(BaseModel):
        """Definition of a single tool argument."""

        description: Optional[str]
        # Python-style ("str", "int", ...) or JSON-schema-style ("string",
        # "integer", ...) type name.
        type: str = "str"
        # Allowed values, when the argument is restricted to a fixed set.
        enum: Optional[List] = None
        # A falsy value (False or None) marks the argument as optional.
        required: Optional[bool] = True

    @model_validator(mode="before")
    @classmethod
    def validate_all(cls, values):
        """Normalize every raw argument definition into an ``ArgInfo``.

        A plain string is used as the argument's description, a dict is
        expanded into ``ArgInfo`` keyword arguments, and an existing
        ``ArgInfo`` is kept as is.

        Raises:
            ValueError: if a definition is none of string, dict or ArgInfo.
        """
        if not isinstance(values, dict):
            # Let pydantic handle non-dict input (e.g. an existing instance).
            return values
        normalized = {}
        for key, value in values.items():
            if isinstance(value, cls.ArgInfo):
                normalized[key] = value
            elif isinstance(value, str):
                # ArgInfo has no ``name`` field and ``description`` has no
                # default, so the string shorthand must fill ``description``.
                normalized[key] = cls.ArgInfo(description=value)
            elif isinstance(value, dict):
                normalized[key] = cls.ArgInfo(**value)
            else:
                raise ValueError(
                    "The arg type must be one of string, dict or self.ArgInfo."
                )
        return normalized

    @classmethod
    def from_file(cls, schema_file: Union[str, Path]):
        """Build a schema from a ``.json``, ``.yaml`` or ``.yml`` file.

        Args:
            schema_file: Path to the schema file.

        Raises:
            ValueError: if the file extension is not supported.
        """
        schema_file = Path(schema_file)
        suffix = schema_file.suffix.lower()
        if suffix == ".json":
            with open(schema_file, "r", encoding="utf-8") as f:
                schema = json.load(f)
        elif suffix in (".yaml", ".yml"):
            # NOTE(review): FullLoader can construct some Python objects; use
            # yaml.safe_load if schema files may come from untrusted sources.
            with open(schema_file, "r", encoding="utf-8") as f:
                schema = yaml.load(f, Loader=yaml.FullLoader)
        else:
            raise ValueError("Only support json and yaml file.")
        # An empty YAML file loads as None; treat it as an empty schema.
        return cls(**(schema or {}))

    def generate_schema(self) -> tuple:
        """Build the JSON-schema ``properties`` mapping and ``required`` list.

        Returns:
            tuple: ``(parameters, required_args)``. ``parameters`` maps each
            argument name to its definition, with None values dropped and
            ``type`` translated to its JSON-schema name. ``required_args``
            lists the names of the required arguments.
        """
        parameters = {}
        required_args = []
        for key, info in self.model_dump(exclude_none=True).items():
            # ``required`` moves into the separate list. exclude_none drops it
            # when it is None, which counts as "not required".
            if info.pop("required", None):
                required_args.append(key)
            info["type"] = _ARG_TYPE_ALIASES.get(info["type"], info["type"])
            parameters[key] = info
        return parameters, required_args

    @staticmethod
    def _coerce(name: str, value: Any, declared_type: str) -> Any:
        """Convert ``value`` to ``declared_type``.

        Args:
            name: Argument name, used in error messages.
            value: Raw argument value.
            declared_type: Python-style or JSON-schema-style type name.

        Raises:
            ValueError: if the value cannot be converted or the declared type
                is unsupported.
        """
        expected = _ARG_TYPE_ALIASES.get(declared_type, declared_type)
        if _ARG_TYPE_ALIASES.get(type(value).__name__) == expected:
            return value
        if expected == "string":
            return str(value)
        if expected == "integer":
            try:
                return int(value)
            except (TypeError, ValueError, OverflowError) as err:
                raise ValueError(
                    "Parameter {} type expect an int value, but got a {} {}".format(
                        name, type(value), value
                    )
                ) from err
        if expected == "number":
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError) as err:
                raise ValueError(
                    "Parameter {} type expect a float value, but got a {} {}".format(
                        name, type(value), value
                    )
                ) from err
        if expected == "boolean":
            try:
                # strtobool returns 1/0, so cast to a real bool.
                return bool(strtobool(str(value)))
            except ValueError as err:
                raise ValueError(
                    "Parameter {} type expect a boolean value, but got a {} {}".format(
                        name, type(value), value
                    )
                ) from err
        raise ValueError(
            "Parameter {} type expect one of string, integer, number and boolean, but got a {} type".format(
                name, declared_type
            )
        )

    def validate_args(self, args: dict) -> dict:
        """Validate tool arguments and coerce them against this schema.

        Unknown arguments are dropped with a warning. None values are treated
        as not supplied. Every other value is converted to its declared type
        and checked against ``enum`` when one is set.

        Args:
            args: Raw argument mapping.

        Returns:
            dict: The validated and converted arguments.

        Raises:
            ValueError: if ``args`` is not a dict, a value cannot be
                converted, or a value is outside its enum.
            VQLError: if a required argument is missing.
        """
        if not isinstance(args, dict):
            raise ValueError(
                "ArgSchema validate only support dict, not {}".format(type(args))
            )
        # Dump once instead of on every access inside the loop.
        schema = self.model_dump()
        required_fields = {key for key, info in schema.items() if info["required"]}

        new_args = {}
        for name, value in args.items():
            if name not in schema:
                logging.warning(
                    "The input args includes an unnecessary parameter {}. Removed from the args.".format(
                        name
                    )
                )
                continue
            if value is None:
                # A null value counts as "not supplied". The required-field
                # check below still rejects a missing mandatory argument.
                continue
            info = schema[name]
            new_value = self._coerce(name, value, info["type"])
            # Check the enum after coercion so converted values are checked too.
            if info["enum"] and new_value not in info["enum"]:
                raise ValueError(
                    "The value of {} should be one of {}, but got {}".format(
                        name, str(info["enum"]), value
                    )
                )
            new_args[name] = new_value

        missing = required_fields - set(new_args)
        if missing:
            raise VQLError("The required fields {} are missing.".format(missing))
        return new_args
class BaseTool(BotBase, ABC):
    """Base class for tools.

    A tool either overrides ``_run``/``_arun`` or receives a ``func`` at
    construction time. When ``args_schema`` is set, ``run``/``arun`` require a
    dict input. They validate and coerce it with ``ArgSchema.validate_args``
    and pass the validated arguments on.
    """

    description: str
    func: Optional[Callable] = None
    args_schema: Optional[ArgSchema]
    # Extra keyword arguments that are always forwarded to ``_run``/``_arun``.
    special_params: Dict = {}

    def model_post_init(self, __context: Any) -> None:
        """Set ``_parent`` to this tool on every ``BotBase`` attribute."""
        for attr_value in self.__dict__.values():
            if isinstance(attr_value, BotBase):
                attr_value._parent = self

    @property
    def workflow_instance_id(self) -> Optional[str]:
        """Workflow instance id of the parent, or None when there is no parent."""
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        # Without a parent the value is silently dropped.
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    def _prepare_input(self, input: Any) -> Any:
        """Return the arguments to pass to the tool.

        Raises:
            ValueError: if ``args_schema`` is set and ``input`` is not a dict,
                or whatever ``validate_args`` raises.
        """
        if self.args_schema is None:
            return input
        if not isinstance(input, dict):
            raise ValueError(
                "The input type must be dict when args_schema is specified."
            )
        # Use the validated result: it removes unknown keys and coerces types.
        return self.args_schema.validate_args(input)

    def _run(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        if self.func is None:
            raise NotImplementedError(
                "{} must implement _run or be given 'func' when initializing.".format(
                    type(self).__name__
                )
            )
        return self.func(**input)

    async def _arun(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing.

        ``func`` may be a coroutine function or a plain function.
        """
        import inspect

        if self.func is None:
            raise NotImplementedError(
                "{} must implement _arun or be given 'func' when initializing.".format(
                    type(self).__name__
                )
            )
        result = self.func(**input)
        if inspect.isawaitable(result):
            result = await result
        return result

    def run(self, input: Any) -> str:
        """Run the tool synchronously with the validated ``input``."""
        return self._run(**self._prepare_input(input), **self.special_params)

    async def arun(self, input: Any) -> str:
        """Run the tool asynchronously with the validated ``input``."""
        return await self._arun(**self._prepare_input(input), **self.special_params)

    def generate_schema(self) -> dict:
        """Describe this tool in the function-calling ``tools`` format.

        ``name``, ``description`` and ``parameters`` all sit inside the
        ``"function"`` object in both branches.
        """
        if self.args_schema is not None:
            properties, required = self.args_schema.generate_schema()
            parameters = {
                "type": "object",
                "properties": properties,
                "required": required,
            }
        else:
            parameters = {
                "type": "object",
                "name": "input",
                "required": ["input"],
            }
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }
class BaseModelTool(BaseTool, ABC):
    """Base class for tools backed by an object-detection model."""

    def visual_prompting(
        self,
        image: Image.Image,
        annotation: List[Target],
        prompting_type: str = "label_on_img",
        include_labels: Optional[Union[List, set, tuple]] = None,
        exclude_labels: Optional[Union[List, set, tuple]] = None,
    ) -> List[Image.Image]:
        """Draw the bounding boxes of ``annotation`` onto ``image``.

        Args:
            image: The image to annotate.
            annotation: Targets to draw; each target's ``label`` and ``bbox``
                are used.
            prompting_type: Not used by this implementation.
            include_labels: If given, only targets whose label is in it are drawn.
            exclude_labels: If given, targets whose label is in it are skipped.

        Returns:
            The value of ``Annotator.result()``. NOTE(review): annotated as a
            list of images; confirm against ``Annotator``.
        """
        annotator = Annotator(image)
        for obj in annotation:
            if exclude_labels is not None and obj.label in exclude_labels:
                continue
            if include_labels is not None and obj.label not in include_labels:
                continue
            # Targets without a bbox are skipped.
            if obj.bbox:
                annotator.box_label(obj.bbox, obj.label, color="red")
            # TODO: Add polygon support
        return annotator.result()

    def infer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """The model inference step. Only support OD type detection.

        This base implementation does nothing and returns None; presumably
        subclasses override it.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """

    def ainfer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """The async counterpart of ``infer``. Only support OD type detection.

        NOTE(review): declared with plain ``def``, not ``async def``. This base
        implementation does nothing and returns None.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """
class MemoryTool(BaseTool):
    """Tool that exposes the SQL table managed by ``memory_handler``."""

    memory_handler: Optional[SQLDataHandler]

    def generate_schema(self) -> dict:
        """Generate the data table schema in dict format.

        Returns:
            dict: ``{"table_name": ..., "columns": [...]}``. Each column entry
            holds the column's ``name``, its type's ``__visit_name__`` and its
            ``info`` mapping.
        """
        table = self.memory_handler.table
        columns = [
            {
                "name": column.name,
                "type": column.type.__visit_name__,
                "info": column.info,
            }
            for column in table.__table__.columns
        ]
        return {"table_name": table.__tablename__, "columns": columns}

    def generate_prompt(self):
        """Placeholder; currently does nothing and returns None."""
        pass

    def _run(self):
        """Run ``memory_handler.execute_sql()`` and return its result.

        NOTE(review): takes no arguments, so ``BaseTool.run`` with a non-empty
        input (or non-empty ``special_params``) raises TypeError here.
        """
        return self.memory_handler.execute_sql()

    async def _arun(self):
        """Async entry point; calls the synchronous ``execute_sql`` directly.

        NOTE(review): this blocks the event loop while the query runs.
        """
        return self.memory_handler.execute_sql()