Spaces:
Sleeping
Sleeping
import json | |
from abc import ABC | |
from distutils.util import strtobool | |
from pathlib import Path | |
from typing import Any, Callable, Dict, List, Optional, Union | |
import yaml | |
from omagent_core.base import BotBase | |
from omagent_core.models.od.schemas import Target | |
from omagent_core.services.handlers.sql_data_handler import SQLDataHandler | |
from omagent_core.utils.error import VQLError | |
from omagent_core.utils.logger import logging | |
from omagent_core.utils.plot import Annotator | |
from PIL import Image | |
from pydantic import BaseModel, model_validator | |
class ArgSchema(BaseModel): | |
"""ArgSchema defines the tool input schema. Only support one layer definition. Please prevent using complex structure.""" | |
class Config: | |
"""Configuration for this pydantic object.""" | |
extra = "allow" | |
arbitrary_types_allowed = True | |
class ArgInfo(BaseModel): | |
description: Optional[str] | |
type: str = "str" | |
enum: Optional[List] = None | |
required: Optional[bool] = True | |
def validate_all(cls, values): | |
for key, value in values.items(): | |
if type(value) is str: | |
values[key] = cls.ArgInfo(name=value) | |
elif type(value) is dict: | |
values[key] = cls.ArgInfo(**value) | |
elif type(value) is cls.ArgInfo: | |
pass | |
else: | |
raise ValueError( | |
"The arg type must be one of string, dict or self.ArgInfo." | |
) | |
return values | |
def from_file(cls, schema_file: Union[str, Path]): | |
if type(schema_file) is str: | |
schema_file = Path(schema_file) | |
if schema_file.suffix == ".json": | |
with open(schema_file, "r") as f: | |
schema = json.load(f) | |
elif schema_file.suffix == ".yaml": | |
with open(schema_file, "r") as f: | |
schema = yaml.load(f, Loader=yaml.FullLoader) | |
else: | |
raise ValueError("Only support json and yaml file.") | |
return cls(**schema) | |
def generate_schema(self) -> Union[dict, list]: | |
required_args = [] | |
parameters = {} | |
for key, value in self.model_dump(exclude_none=True).items(): | |
parameters[key] = value | |
if parameters[key].pop("required"): | |
required_args.append(key) | |
return parameters, required_args | |
def validate_args(self, args: dict) -> dict: | |
if type(args) is not dict: | |
raise ValueError( | |
"ArgSchema validate only support dict, not {}".format(type(args)) | |
) | |
new_args = {} | |
required_fields = set( | |
[k for k, v in self.model_dump().items() if v["required"]] | |
) | |
name_mapping = { | |
"str": "string", | |
"int": "integer", | |
"float": "number", | |
"bool": "boolean", | |
} | |
for name, value in args.items(): | |
if name not in self.model_dump(): | |
logging.warning( | |
"The input args includes an unnecessary parameter {}. Removed from the args.".format( | |
name | |
) | |
) | |
continue | |
if name_mapping[type(value).__name__] == self.model_dump()[name]["type"]: | |
if ( | |
self.model_dump()[name]["enum"] | |
and value not in self.model_dump()[name]["enum"] | |
): | |
raise ValueError( | |
"The value of {} should be one of {}, but got {}".format( | |
name, str(self.model_dump()[name]["enum"]), value | |
) | |
) | |
new_args[name] = value | |
elif self.model_dump()[name]["type"] == "string": | |
try: | |
new_args[name] = str(value) | |
except: | |
raise ValueError( | |
"Parameter {} type expect a str value, but got a {} {}".format( | |
name, type(value), value | |
) | |
) | |
elif self.model_dump()[name]["type"] == "integer": | |
try: | |
new_args[name] = int(value) | |
except: | |
raise ValueError( | |
"Parameter {} type expect an int value, but got a {} {}".format( | |
name, type(value), value | |
) | |
) | |
elif self.model_dump()[name]["type"] == "number": | |
try: | |
new_args[name] = float(value) | |
except: | |
raise ValueError( | |
"Parameter {} type expect a float value, but got a {} {}".format( | |
name, type(value), value | |
) | |
) | |
elif self.model_dump()[name]["type"] == "boolean": | |
if type(value) is bool: | |
new_args[name] = value | |
else: | |
try: | |
new_args[name] = strtobool(str(value)) | |
except: | |
raise ValueError( | |
"Parameter {} type expect a boolean value, but got a {} {}".format( | |
name, type(value), value | |
) | |
) | |
else: | |
raise ValueError( | |
"Parameter {} type expect one of string, integer, number and boolean, but got a {} {}".format( | |
name, self.model_dump()[name]["type"], type(value), value | |
) | |
) | |
if required_fields - set(new_args.keys()): | |
raise VQLError( | |
"The required fields {} are missing.".format( | |
required_fields - set(new_args.keys()) | |
) | |
) | |
return new_args | |
class BaseTool(BotBase, ABC): | |
description: str | |
func: Optional[Callable] = None | |
args_schema: Optional[ArgSchema] | |
special_params: Dict = {} | |
def model_post_init(self, __context: Any) -> None: | |
for _, attr_value in self.__dict__.items(): | |
if isinstance(attr_value, BotBase): | |
attr_value._parent = self | |
def workflow_instance_id(self) -> str: | |
if hasattr(self, "_parent"): | |
return self._parent.workflow_instance_id | |
return None | |
def workflow_instance_id(self, value: str): | |
if hasattr(self, "_parent"): | |
self._parent.workflow_instance_id = value | |
def _run(self, **input) -> str: | |
"""Implement this function or pass 'func' arg when initializing.""" | |
return self.func(**input) | |
async def _arun(self, **input) -> str: | |
"""Implement this function or pass 'func' arg when initializing.""" | |
return await self.func(**input) | |
def run(self, input: Any) -> str: | |
if self.args_schema != None: | |
if type(input) != dict: | |
raise ValueError( | |
"The input type must be dict when args_schema is specified." | |
) | |
self.args_schema.validate_args(input) | |
return self._run(**input, **self.special_params) | |
async def arun(self, input: Any) -> str: | |
if self.args_schema != None: | |
if type(input) != dict: | |
raise ValueError( | |
"The input type must be dict when args_schema is specified." | |
) | |
self.args_schema.validate_args(input) | |
return await self._arun(**input, **self.special_params) | |
def generate_schema(self): | |
if not self.args_schema: | |
return { | |
"type": "function", | |
"description": self.description, | |
"function": { | |
"name": self.name, | |
"parameters": { | |
"type": "object", | |
"name": "input", | |
"required": ["input"], | |
}, | |
}, | |
} | |
else: | |
properties, required = self.args_schema.generate_schema() | |
return { | |
"type": "function", | |
"function": { | |
"name": self.name, | |
"description": self.description, | |
"parameters": { | |
"type": "object", | |
"properties": properties, | |
"required": required, | |
}, | |
}, | |
} | |
class BaseModelTool(BaseTool, ABC): | |
# data_handler: Optional[SQLDataHandler] | |
def visual_prompting( | |
self, | |
image: Image.Image, | |
annotation: List[Target], | |
prompting_type: str = "label_on_img", | |
include_labels: Union[List, set, tuple] = None, | |
exclude_labels: Union[List, set, tuple] = None, | |
) -> List[Image.Image]: | |
annotator = Annotator(image) | |
for obj in annotation: | |
if (exclude_labels is not None and obj.label in exclude_labels) or ( | |
include_labels is not None and obj.label not in include_labels | |
): | |
continue | |
if obj.bbox: | |
annotator.box_label(obj.bbox, obj.label, color="red") | |
# TODO: Add polygon support | |
return annotator.result() | |
def infer(self, images: List[Image.Image], kwargs) -> List[List[Target]]: | |
"""The model inference step. Only support OD type detection. | |
Args: | |
images (List[Image.Image]): The list of input images. Image should be PIL Image object. | |
kwargs (dict): The additional arguments for the model. | |
Returns: | |
List[List[Target]]: The detection results. | |
""" | |
def ainfer(self, images: List[Image.Image], kwargs) -> List[List[Target]]: | |
"""The async version of model inference step. Only support OD type detection. | |
Args: | |
images (List[Image.Image]): The list of input images. Image should be PIL Image object. | |
kwargs (dict): The additional arguments for the model. | |
Returns: | |
List[List[Target]]: The detection results. | |
""" | |
class MemoryTool(BaseTool): | |
memory_handler: Optional[SQLDataHandler] | |
def generate_schema(self) -> dict: | |
"""Generate the data table schema in dict format. | |
Returns: | |
dict: The data table schema. Including the table name, and the name, data type and additional information of each column. | |
""" | |
table = self.memory_handler.table | |
schema = {"table_name": table.__tablename__, "columns": []} | |
for column in table.__table__.columns: | |
schema["columns"].append( | |
{ | |
"name": column.name, | |
"type": column.type.__visit_name__, | |
"info": column.info, | |
} | |
) | |
return schema | |
def generate_prompt(self): | |
pass | |
def _run(self): | |
self.memory_handler.execute_sql() | |
async def _arun(self): | |
self.memory_handler.execute_sql() | |