# alphagenome_agent/managers/tools/tool_registry.py
# Uploaded by Paper2Agent ("Upload 56 files"), commit 8b54db1 (verified), 10.9 kB.
"""
Tool Registry System for CodeAct - Based on biomni's approach.
Provides centralized tool management with schema validation and discovery.
"""
import ast
import importlib
import importlib.util
import inspect
import os
import pickle
from typing import Any, Dict, List, Optional, Callable
import pandas as pd
# Names exported by ``from tool_registry import *``.
__all__ = [
    "ToolRegistry",
    "tool",
    "discover_tools_in_module",
    "create_module2api_from_functions",
]
class ToolRegistry:
    """
    Central registry for managing tools, similar to biomni's ToolRegistry.
    Handles tool registration, validation, and lookup by name/ID.

    Attributes:
        tools: Registered tool schemas in registration order. Each entry is a
            copy of the schema passed in, with ``id`` and ``module_name`` added.
        next_id: ID that will be assigned to the next registered tool.
        document_df: DataFrame with columns ``docid`` and ``document_content``,
            rebuilt whenever tools are registered or removed.
    """

    def __init__(self, module2api: Optional[Dict[str, List[Dict]]] = None):
        """
        Initialize the tool registry.

        Args:
            module2api: Dictionary mapping module names to lists of tool schemas,
                e.g. {"module.tools": [{"name": "func1", "description": "...", ...}]}

        Raises:
            ValueError: If any schema in ``module2api`` lacks a required field.
        """
        self.tools = []
        self.next_id = 0
        self._name_to_function = {}  # Map tool names to actual functions
        # Register tools from module2api if provided. The retrieval DataFrame
        # is built once below instead of once per registered tool.
        if module2api:
            for module_name, tool_list in module2api.items():
                for tool_schema in tool_list:
                    self._add_tool(tool_schema, module_name)
        # Create document dataframe for retrieval (similar to biomni)
        self._create_document_df()

    def register_tool(self, tool_schema: Dict, module_name: Optional[str] = None) -> bool:
        """
        Register a tool with the registry and refresh ``document_df``.

        The schema is copied, given the next unique ``id`` and tagged with
        ``module_name``. If ``module_name`` is given, the module is imported and
        the attribute named like the tool is bound as its function. A failure
        there only prints a warning; the tool is still registered.

        Args:
            tool_schema: Tool schema dictionary with at least ``name`` and
                ``description`` (typically also required/optional parameters).
            module_name: Optional module name where the tool function is located.

        Returns:
            True once the tool has been registered.

        Raises:
            ValueError: If ``tool_schema`` lacks a required field.
        """
        self._add_tool(tool_schema, module_name)
        # Keep the retrieval DataFrame in sync, as remove_tool_by_name does.
        self._create_document_df()
        return True

    def _add_tool(self, tool_schema: Dict, module_name: Optional[str]) -> None:
        """Validate, copy and append a schema without refreshing ``document_df``."""
        if not self.validate_tool(tool_schema):
            raise ValueError(f"Invalid tool format for {tool_schema.get('name', 'unknown')}")
        # Copy so the caller's dict is not mutated by the bookkeeping keys.
        tool_schema = tool_schema.copy()
        tool_schema["id"] = self.next_id
        tool_schema["module_name"] = module_name
        # Try to load the actual function if module_name provided
        if module_name:
            try:
                function = self._load_function_from_module(module_name, tool_schema["name"])
                # Only bind real callables; a same-named constant is not a tool.
                if callable(function):
                    self._name_to_function[tool_schema["name"]] = function
            except Exception as e:
                # Best effort: importing the module may run arbitrary code and
                # fail in any way; the schema is still registered.
                print(f"Warning: Could not load function {tool_schema['name']} from {module_name}: {e}")
        self.tools.append(tool_schema)
        self.next_id += 1

    def validate_tool(self, tool_schema: Dict) -> bool:
        """Return True if ``tool_schema`` has the required ``name`` and ``description`` keys."""
        required_keys = ("name", "description")
        return all(key in tool_schema for key in required_keys)

    def get_tool_by_name(self, name: str) -> Optional[Dict]:
        """Return the first registered schema named ``name``, or None."""
        return next((tool for tool in self.tools if tool["name"] == name), None)

    def get_tool_by_id(self, tool_id: int) -> Optional[Dict]:
        """Return the registered schema whose ``id`` is ``tool_id``, or None."""
        return next((tool for tool in self.tools if tool["id"] == tool_id), None)

    def get_function_by_name(self, name: str) -> Optional[Callable]:
        """Get the actual function object by name, or None if none is bound."""
        return self._name_to_function.get(name)

    def get_all_functions(self) -> Dict[str, Callable]:
        """Return a shallow copy of the name -> function mapping."""
        return self._name_to_function.copy()

    def list_tools(self) -> List[Dict]:
        """List all registered tools with basic info (name, id, description)."""
        return [
            {"name": tool["name"], "id": tool["id"], "description": tool["description"]}
            for tool in self.tools
        ]

    def list_tool_names(self) -> List[str]:
        """Get list of all tool names in registration order."""
        return [tool["name"] for tool in self.tools]

    def add_function_directly(self, name: str, function: Callable, description: Optional[str] = None) -> bool:
        """
        Add a function directly to the registry.

        Args:
            name: Function name
            function: The callable function
            description: Optional description; falls back to the function's
                docstring, then to "Function <name>".

        Returns:
            True if added successfully
        """
        if description is None:
            description = function.__doc__ or f"Function {name}"
        # Create schema from function signature
        schema = self._create_schema_from_function(name, function, description)
        # Registered without a module_name, so bind the callable explicitly.
        self.register_tool(schema)
        self._name_to_function[name] = function
        return True

    def _create_schema_from_function(self, name: str, function: Callable, description: str) -> Dict:
        """
        Build a tool schema from ``function``'s signature.

        Parameters without a default are listed as required. Parameters with a
        default are listed as optional and carry it under ``default``.
        ``*args`` and ``**kwargs`` are skipped because they cannot be passed
        as named parameters.
        """
        sig = inspect.signature(function)
        schema = {
            "name": name,
            "description": description,
            "required_parameters": [],
            "optional_parameters": [],
        }
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        for param_name, param in sig.parameters.items():
            if param.kind in variadic_kinds:
                continue
            param_info = {
                "name": param_name,
                "type": self._get_param_type(param),
                "description": f"Parameter {param_name}",
            }
            # Identity check: ``==`` against an array-like default (numpy,
            # pandas) is elementwise and would make this ``if`` raise.
            if param.default is inspect.Parameter.empty:
                schema["required_parameters"].append(param_info)
            else:
                param_info["default"] = param.default
                schema["optional_parameters"].append(param_info)
        return schema

    def _get_param_type(self, param: inspect.Parameter) -> str:
        """
        Return the parameter's annotation as a string, or "Any" if unannotated.

        Plain classes render as their ``__name__`` (e.g. "int"). Parameterized
        generics (anything exposing ``__args__``, e.g. ``List[int]``) render
        via ``str()`` so their type arguments are kept. On Python 3.10+,
        ``List[int].__name__`` is just "List".
        """
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            return "Any"
        if getattr(annotation, "__args__", None) is None and hasattr(annotation, "__name__"):
            return annotation.__name__
        return str(annotation)

    def _load_function_from_module(self, module_name: str, function_name: str) -> Optional[Callable]:
        """
        Import ``module_name`` and return its attribute ``function_name``.

        Returns None on ImportError/AttributeError or if the attribute is
        missing. Other exceptions raised while importing propagate.
        """
        try:
            module = importlib.import_module(module_name)
            return getattr(module, function_name, None)
        except (ImportError, AttributeError):
            return None

    def _create_document_df(self):
        """Rebuild ``document_df`` (one row per tool) for retrieval, similar to biomni."""
        docs = []
        for tool in self.tools:
            doc_content = {
                "name": tool["name"],
                "description": tool["description"],
                "required_parameters": tool.get("required_parameters", []),
                "optional_parameters": tool.get("optional_parameters", []),
                "module_name": tool.get("module_name", ""),
            }
            docs.append([tool["id"], doc_content])
        self.document_df = pd.DataFrame(docs, columns=["docid", "document_content"])

    def remove_tool_by_name(self, name: str) -> bool:
        """Remove every tool named ``name``; return True if any was removed."""
        tool = self.get_tool_by_name(name)
        if tool:
            self.tools = [t for t in self.tools if t["name"] != name]
            self._name_to_function.pop(name, None)
            self._create_document_df()  # Refresh document df
            return True
        return False

    def save_registry(self, filename: str):
        """
        Pickle the whole registry (schemas, bound functions, document_df) to ``filename``.

        Note: pickle stores functions by reference. Functions that cannot be
        re-imported by module and qualified name (e.g. lambdas or nested
        functions) make this fail.
        """
        with open(filename, "wb") as file:
            pickle.dump(self, file)

    @staticmethod
    def load_registry(filename: str) -> 'ToolRegistry':
        """
        Load a registry previously written by ``save_registry``.

        Security: unpickling can execute arbitrary code. Only load files from
        a trusted source.
        """
        with open(filename, "rb") as file:
            return pickle.load(file)
# ====================
# TOOL DECORATOR
# ====================
def tool(func: Callable = None, *, name: str = None, description: str = None):
    """
    Mark a function as a tool by attaching metadata attributes to it.

    The decorated function itself is returned (it is not wrapped), with
    ``_tool_name``, ``_tool_description`` and ``_is_tool`` set on it. The name
    defaults to the function's ``__name__``. The description defaults to its
    docstring, then to "Function <name>".

    Usage:
        @tool
        def my_function(x: int) -> int:
            '''This function does something'''
            return x * 2

        @tool(name="custom_name", description="Custom description")
        def another_function():
            pass
    """
    def mark_as_tool(target):
        target._tool_name = name or target.__name__
        target._tool_description = (
            description or target.__doc__ or f"Function {target.__name__}"
        )
        target._is_tool = True
        return target

    # ``@tool`` passes the function positionally; ``@tool(...)`` passes only
    # keywords and expects a decorator back.
    return mark_as_tool if func is None else mark_as_tool(func)
# ====================
# TOOL DISCOVERY UTILITIES
# ====================
def _is_tool_decorator(decorator: ast.expr) -> bool:
    """Return True if an AST decorator is ``tool``/``x.tool``, bare or called."""
    target = decorator.func if isinstance(decorator, ast.Call) else decorator
    if isinstance(target, ast.Name):
        return target.id == "tool"
    if isinstance(target, ast.Attribute):
        return target.attr == "tool"
    return False


def discover_tools_in_module(module_path: str) -> List[Callable]:
    """
    Discover all functions marked with @tool decorator in a module.

    The file is first scanned with ``ast`` for function definitions (sync or
    async) decorated with ``tool``, ``tool(...)``, ``x.tool`` or ``x.tool(...)``.
    The file is then executed as a module named "temp_module". Each candidate
    name is looked up at module level, and only objects carrying the
    ``_is_tool`` marker set by ``@tool`` are returned. Decorated methods and
    nested functions are therefore not returned.

    Warning: importing executes the file's top-level code; only use this on
    trusted files.

    Args:
        module_path: Path to the Python module file (its suffix must be one an
            import loader handles, e.g. ``.py``).

    Returns:
        List of function objects marked as tools, in discovery order, with no
        duplicates.

    Raises:
        ImportError: If no import loader can handle ``module_path``.
        SyntaxError: If the file is not valid Python.
    """
    # Parse bytes so ast honours a PEP 263 coding declaration (defaulting to
    # UTF-8) rather than the platform's locale encoding.
    with open(module_path, "rb") as file:
        tree = ast.parse(file.read(), filename=module_path)

    # Find functions with a @tool-style decorator
    candidate_names = [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        and any(_is_tool_decorator(dec) for dec in node.decorator_list)
    ]

    # Import the module and get function objects
    spec = importlib.util.spec_from_file_location("temp_module", module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_path!r}: no loader for this file type")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    tool_functions = []
    # dict.fromkeys keeps first-seen order and drops repeated names (e.g. a
    # decorated method sharing a name with a module-level tool).
    for name in dict.fromkeys(candidate_names):
        func = getattr(module, name, None)
        if func is not None and hasattr(func, "_is_tool"):
            tool_functions.append(func)
    return tool_functions
def create_module2api_from_functions(functions: List[Callable], module_name: str = "custom_tools") -> Dict[str, List[Dict]]:
    """
    Create a module2api dictionary from a list of functions.

    Each function is described by the name/description attached by the
    ``@tool`` decorator (``_tool_name`` / ``_tool_description``) when present.
    Otherwise its ``__name__`` and docstring are used, with "Function <name>"
    as the fallback description.

    Args:
        functions: List of function objects
        module_name: Name to assign to the module

    Returns:
        Dictionary in module2api format, ``{module_name: [schema, ...]}``, with
        schemas in the same order as ``functions``.
    """
    # Schema building only needs a registry instance for its helper. Create it
    # once rather than per function, since each construction builds a DataFrame.
    schema_builder = ToolRegistry()
    tool_schemas = []
    for func in functions:
        name = getattr(func, "_tool_name", func.__name__)
        description = getattr(func, "_tool_description", func.__doc__ or f"Function {func.__name__}")
        # Create schema from function signature
        tool_schemas.append(schema_builder._create_schema_from_function(name, func, description))
    return {module_name: tool_schemas}