Paper2Agent's picture
Upload 56 files
8b54db1 verified
raw
history blame
18.8 kB
"""
Tool Manager for CodeAct Agent.
Unified management system for all types of tools: local functions, decorated tools, and MCP tools.
"""
from typing import Dict, List, Optional, Callable, Any, Union
from enum import Enum
from dataclasses import dataclass
import re
# Import components
from .tool_registry import ToolRegistry, create_module2api_from_functions
from .mcp_manager import MCPManager
class ToolSource(Enum):
    """Origin of a tool recorded in the ToolManager catalog.

    ``ALL`` does not describe a real origin; it is the wildcard value used by
    filtering helpers (e.g. ``list_tools``) to mean "any source".
    """

    LOCAL = "local"  # default source for tools added via add_tool
    DECORATED = "decorated"  # tools found by decorated-tool discovery
    MCP = "mcp"  # tools loaded through the MCP manager
    ALL = "all"  # filter wildcard, not an actual origin
@dataclass
class ToolInfo:
    """Catalog record describing a single tool, whatever its source.

    Attributes:
        name: Tool name; ToolManager uses it as the catalog key.
        description: Human-readable description of the tool.
        source: Which subsystem the tool came from.
        function: Callable implementing the tool, when one is available.
        schema: Raw tool schema dict, when one was supplied or built.
        server: Name of the originating MCP server (MCP tools only).
        module: Module the tool belongs to, when known.
        required_parameters: Parameter dicts rendered as required schema fields.
        optional_parameters: Parameter dicts rendered as optional schema fields.
    """

    name: str
    description: str
    source: ToolSource
    function: Optional[Callable] = None
    schema: Optional[Dict] = None
    server: Optional[str] = None  # For MCP tools
    module: Optional[str] = None
    # ``None`` is accepted here because callers forward values such as
    # ``dict.get(...)`` results; ``__post_init__`` replaces it with a fresh
    # empty list, so after construction these are always lists.
    required_parameters: Optional[List[Dict]] = None
    optional_parameters: Optional[List[Dict]] = None

    def __post_init__(self):
        # Normalize missing parameter lists so consumers can len()/iterate
        # them without None checks (a fresh list per instance, never shared).
        if self.required_parameters is None:
            self.required_parameters = []
        if self.optional_parameters is None:
            self.optional_parameters = []
class ToolManager:
    """
    Unified tool management system for CodeAct Agent.

    Manages all types of tools:
    - Local functions (legacy function registry)
    - Decorated tools (@tool decorator)
    - MCP tools (Model Context Protocol)

    Every tool is recorded as a ``ToolInfo`` in ``self._tool_catalog`` keyed by
    tool name. Callables are also registered with the underlying
    ``ToolRegistry``, and LOCAL callables are mirrored in
    ``self._legacy_functions``.

    Provides a single, consistent interface for tool operations.
    """

    def __init__(self, console_display=None):
        """
        Initialize the ToolManager.

        Args:
            console_display: Optional console display, forwarded to the
                MCPManager for MCP status output.
        """
        # Core components
        self.tool_registry = ToolRegistry()
        self.mcp_manager = MCPManager(console_display)
        # Unified tool catalog: tool name -> ToolInfo
        self._tool_catalog: Dict[str, ToolInfo] = {}
        # Legacy function registry (for backward compatibility)
        self._legacy_functions: Dict[str, Callable] = {}
        # Initialize with decorated tools from the builtin_tools module
        self._discover_decorated_tools()

    # ====================
    # CORE TOOL MANAGEMENT
    # ====================

    def add_tool(self, tool: Union[Callable, Dict], name: Optional[str] = None,
                 description: Optional[str] = None,
                 source: ToolSource = ToolSource.LOCAL) -> bool:
        """
        Add a tool to the manager.

        Args:
            tool: Either a callable function or a tool schema dict.
            name: Optional custom name. Defaults to the callable's ``__name__``
                (or its type name if it has none). For dicts, the schema's
                ``"name"`` key takes precedence over this argument.
            description: Optional description. Defaults to the callable's
                docstring. For dicts, the schema's ``"description"`` key takes
                precedence over this argument.
            source: Source type of the tool.

        Returns:
            True if successfully added, False otherwise. Errors are printed,
            not raised.
        """
        # Best-effort label for the error message; refined once known.
        tool_name = name
        try:
            if callable(tool):
                # Not every callable has __name__ (e.g. functools.partial).
                func_name = getattr(tool, "__name__", None) or type(tool).__name__
                tool_name = name or func_name
                tool_desc = description or tool.__doc__ or f"Function {func_name}"
                # Add to tool registry
                success = self.tool_registry.add_function_directly(tool_name, tool, tool_desc)
                if success:
                    self._tool_catalog[tool_name] = ToolInfo(
                        name=tool_name,
                        description=tool_desc,
                        source=source,
                        function=tool,
                        schema=self._create_schema_from_function(tool_name, tool, tool_desc)
                    )
                    if source == ToolSource.LOCAL:
                        self._legacy_functions[tool_name] = tool
                    return True
            elif isinstance(tool, dict):
                # Handle tool schema dictionaries
                tool_name = tool.get("name") or name
                # Never store None: the description ends up in OpenAI schemas.
                tool_desc = tool.get("description") or description or ""
                if not tool_name:
                    print("Warning: Tool schema must have a name")
                    return False
                # Register with tool registry
                success = self.tool_registry.register_tool(tool)
                if success:
                    self._tool_catalog[tool_name] = ToolInfo(
                        name=tool_name,
                        description=tool_desc,
                        source=source,
                        schema=tool,
                        required_parameters=tool.get("required_parameters", []),
                        optional_parameters=tool.get("optional_parameters", [])
                    )
                    return True
            return False
        except Exception as e:
            print(f"Error adding tool {tool_name}: {e}")
            return False

    def remove_tool(self, name: str) -> bool:
        """
        Remove a tool by name from all registries.

        Returns:
            True if the tool was found in at least one registry.
        """
        try:
            success = False
            # Remove from tool registry
            if self.tool_registry.remove_tool_by_name(name):
                success = True
            # Remove from MCP if it's an MCP tool
            if name in self.mcp_manager.mcp_functions:
                if self.mcp_manager.remove_mcp_tool(name, self.tool_registry):
                    success = True
            # Remove from legacy functions
            if name in self._legacy_functions:
                del self._legacy_functions[name]
                success = True
            # Remove from catalog
            if name in self._tool_catalog:
                del self._tool_catalog[name]
                success = True
            return success
        except Exception as e:
            print(f"Error removing tool {name}: {e}")
            return False

    def get_tool(self, name: str) -> Optional[ToolInfo]:
        """Get comprehensive tool information by name, or None if unknown."""
        return self._tool_catalog.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Get the actual function object by name.

        The catalog is consulted first; otherwise the tool registry is asked.
        """
        tool_info = self.get_tool(name)
        if tool_info and tool_info.function:
            return tool_info.function
        # Check tool registry
        return self.tool_registry.get_function_by_name(name)

    # ====================
    # TOOL DISCOVERY AND LISTING
    # ====================

    def list_tools(self, source: ToolSource = ToolSource.ALL,
                   include_details: bool = False) -> List[Dict]:
        """
        List tools with optional filtering by source.

        Args:
            source: Filter by tool source (LOCAL, DECORATED, MCP, ALL).
            include_details: Whether to include detailed information.

        Returns:
            List of tool dictionaries sorted by name.
        """
        tools = []
        for tool_info in self.get_tools_by_source(source).values():
            entry = {
                "name": tool_info.name,
                "description": tool_info.description,
                "source": tool_info.source.value
            }
            if include_details:
                entry.update({
                    "server": tool_info.server,
                    "module": tool_info.module,
                    "has_function": tool_info.function is not None,
                    "required_params": len(tool_info.required_parameters),
                    "optional_params": len(tool_info.optional_parameters)
                })
            tools.append(entry)
        return sorted(tools, key=lambda x: x["name"])

    def search_tools(self, query: str, search_descriptions: bool = True) -> List[Dict]:
        """
        Search tools by name and optionally description.

        Args:
            query: Search query, matched case-insensitively. Interpreted as a
                regular expression; if it is not a valid regex it is matched
                as a literal substring instead.
            search_descriptions: Whether to also search in descriptions.

        Returns:
            List of matching tools sorted by name.
        """
        try:
            pattern = re.compile(query, re.IGNORECASE)
        except re.error:
            # e.g. "foo(" is not a valid regex; treat it as plain text.
            pattern = re.compile(re.escape(query), re.IGNORECASE)

        matching_tools = []
        for tool_name, tool_info in self._tool_catalog.items():
            match = bool(pattern.search(tool_name)) or (
                search_descriptions and bool(pattern.search(tool_info.description or ""))
            )
            if match:
                matching_tools.append({
                    "name": tool_info.name,
                    "description": tool_info.description,
                    "source": tool_info.source.value,
                    "server": tool_info.server
                })
        return sorted(matching_tools, key=lambda x: x["name"])

    def get_tools_by_source(self, source: ToolSource) -> Dict[str, ToolInfo]:
        """Get all tools from a specific source (ToolSource.ALL returns every tool)."""
        return {
            name: tool_info
            for name, tool_info in self._tool_catalog.items()
            if source == ToolSource.ALL or tool_info.source == source
        }

    # ====================
    # MCP INTEGRATION
    # ====================

    def add_mcp_server(self, config_path: str = "./mcp_config.yaml") -> None:
        """Add MCP tools from a configuration file and mirror them in the catalog.

        Errors are printed, not raised.
        """
        try:
            # Use MCP manager to load tools
            self.mcp_manager.add_mcp(config_path, self.tool_registry)
            # Update our catalog with MCP tools
            self._sync_mcp_tools_into_catalog()
        except Exception as e:
            print(f"Error adding MCP server: {e}")

    def list_mcp_servers(self) -> Dict[str, List[str]]:
        """List all MCP servers and their tool names ("unknown" if no server is set)."""
        servers: Dict[str, List[str]] = {}
        for tool_name, tool_info in self.get_tools_by_source(ToolSource.MCP).items():
            servers.setdefault(tool_info.server or "unknown", []).append(tool_name)
        return servers

    def show_mcp_status(self) -> None:
        """Display detailed MCP status (delegates to the MCP manager)."""
        self.mcp_manager.show_mcp_status()

    def get_mcp_summary(self) -> Dict[str, Any]:
        """Get MCP tools summary (delegates to the MCP manager)."""
        return self.mcp_manager.get_mcp_summary()

    # ====================
    # TOOL EXECUTION SUPPORT
    # ====================

    def get_all_functions(self) -> Dict[str, Callable]:
        """Get all available functions as a name -> callable dictionary.

        Later sources override earlier ones: tool registry, then legacy
        functions, then MCP functions.
        """
        functions = {}
        # Add from tool registry
        functions.update(self.tool_registry.get_all_functions())
        # Add from legacy functions
        functions.update(self._legacy_functions)
        # Add MCP functions
        for tool_name, tool_data in self.mcp_manager.list_mcp_tools().items():
            if tool_data.get("function"):
                functions[tool_name] = tool_data["function"]
        return functions

    def get_tool_schemas(self, openai_format: bool = True) -> List[Dict]:
        """
        Get tool schemas for all tools.

        Args:
            openai_format: Whether to format as OpenAI function schemas. When
                False, the raw stored schemas are returned (tools without one
                are omitted).

        Returns:
            List of tool schemas. Parameter entries that are not dicts with a
            "name" key are skipped (``validate_tools`` reports them).
        """
        schemas = []
        for tool_info in self._tool_catalog.values():
            if not openai_format:
                # Return raw schema
                if tool_info.schema:
                    schemas.append(tool_info.schema)
                continue

            properties: Dict[str, Dict] = {}
            required: List[str] = []
            for param in tool_info.required_parameters:
                if self._is_valid_param(param):
                    properties[param["name"]] = self._param_to_json_schema(param, include_default=False)
                    required.append(param["name"])
            for param in tool_info.optional_parameters:
                if self._is_valid_param(param):
                    properties[param["name"]] = self._param_to_json_schema(param, include_default=True)

            schemas.append({
                "type": "function",
                "function": {
                    "name": tool_info.name,
                    "description": tool_info.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required
                    }
                }
            })
        return schemas

    # ====================
    # STATISTICS AND REPORTING
    # ====================

    def get_tool_statistics(self) -> Dict[str, Any]:
        """Get comprehensive tool statistics."""
        stats = {
            "total_tools": len(self._tool_catalog),
            "by_source": {source.value: 0 for source in ToolSource if source != ToolSource.ALL},
            "with_functions": 0,
            "mcp_servers": len(self.list_mcp_servers()),
            "tool_registry_size": len(self.tool_registry.tools),
            "legacy_functions": len(self._legacy_functions)
        }
        by_source = stats["by_source"]
        for tool_info in self._tool_catalog.values():
            # .get() so a tool registered with an unexpected source (e.g. ALL)
            # is counted instead of raising KeyError.
            key = tool_info.source.value
            by_source[key] = by_source.get(key, 0) + 1
            if tool_info.function:
                stats["with_functions"] += 1
        return stats

    def validate_tools(self) -> Dict[str, List[str]]:
        """Validate all tools and return any issues found.

        Returns:
            Dict with lists of tool names under "missing_functions",
            "missing_descriptions", "duplicate_names" (two catalog entries
            whose ToolInfo.name is the same) and "invalid_schemas" (a
            parameter entry that is not a dict with a "name" key).
        """
        issues = {
            "missing_functions": [],
            "missing_descriptions": [],
            "duplicate_names": [],
            "invalid_schemas": []
        }
        seen_names = set()
        for tool_key, tool_info in self._tool_catalog.items():
            # Catalog keys are unique by construction, so check the name each
            # ToolInfo carries for collisions instead.
            if tool_info.name in seen_names:
                issues["duplicate_names"].append(tool_info.name)
            seen_names.add(tool_info.name)
            # Check for missing functions (except MCP tools which may not have direct functions)
            if not tool_info.function and tool_info.source != ToolSource.MCP:
                issues["missing_functions"].append(tool_key)
            # Check for missing descriptions
            if not tool_info.description or not tool_info.description.strip():
                issues["missing_descriptions"].append(tool_key)
            # Parameters without a name cannot become schema properties.
            params = [*tool_info.required_parameters, *tool_info.optional_parameters]
            if not all(self._is_valid_param(p) for p in params):
                issues["invalid_schemas"].append(tool_key)
        return issues

    # ====================
    # PRIVATE METHODS
    # ====================

    def _discover_decorated_tools(self) -> None:
        """Discover and register tools returned by builtin_tools.get_all_tool_functions()."""
        try:
            from .builtin_tools import get_all_tool_functions
            tool_functions = get_all_tool_functions()
            for func in tool_functions:
                # Decorator-provided metadata wins over the function's own attributes.
                name = getattr(func, '_tool_name', func.__name__)
                description = getattr(func, '_tool_description', func.__doc__ or f"Function {func.__name__}")
                self._tool_catalog[name] = ToolInfo(
                    name=name,
                    description=description,
                    source=ToolSource.DECORATED,
                    function=func,
                    schema=self._create_schema_from_function(name, func, description)
                )
                # Also add to tool registry for consistency
                self.tool_registry.add_function_directly(name, func, description)
        except ImportError:
            print("Warning: Could not import builtin_tools module for decorated tool discovery")

    def _create_schema_from_function(self, name: str, function: Callable, description: str) -> Dict:
        """Create a tool schema from a function object (delegates to the tool registry)."""
        return self.tool_registry._create_schema_from_function(name, function, description)

    @staticmethod
    def _mcp_tool_info(tool_name: str, tool_data: Dict) -> ToolInfo:
        """Build a catalog ToolInfo from one entry of mcp_manager.list_mcp_tools()."""
        return ToolInfo(
            name=tool_name,
            description=tool_data.get("description", "MCP tool"),
            source=ToolSource.MCP,
            function=tool_data.get("function"),
            server=tool_data.get("server"),
            module=tool_data.get("module"),
            required_parameters=tool_data.get("required_parameters", []),
            optional_parameters=tool_data.get("optional_parameters", [])
        )

    def _sync_mcp_tools_into_catalog(self) -> None:
        """Copy every tool currently known to the MCP manager into the catalog."""
        for tool_name, tool_data in self.mcp_manager.list_mcp_tools().items():
            self._tool_catalog[tool_name] = self._mcp_tool_info(tool_name, tool_data)

    @staticmethod
    def _is_valid_param(param: Any) -> bool:
        """Return True if *param* is a dict with the "name" key schemas require."""
        return isinstance(param, dict) and "name" in param

    @staticmethod
    def _param_to_json_schema(param: Dict, include_default: bool) -> Dict:
        """Convert one parameter dict into an OpenAI/JSON-schema property dict."""
        param_schema = {
            "type": param.get("type", "string"),
            "description": param.get("description", "")
        }
        # Add enum values if present
        if "enum" in param:
            param_schema["enum"] = param["enum"]
        if include_default and "default" in param:
            param_schema["default"] = param["default"]
        return param_schema

    def _refresh_catalog(self) -> None:
        """Rebuild the DECORATED and MCP catalog entries from their sources.

        Entries added through ``add_tool`` with any other source have nothing
        to be re-read from, so they are preserved instead of being dropped.
        """
        rediscoverable = (ToolSource.DECORATED, ToolSource.MCP)
        preserved = {
            name: info
            for name, info in self._tool_catalog.items()
            if info.source not in rediscoverable
        }
        # Clear current catalog
        self._tool_catalog.clear()
        # Re-discover decorated tools
        self._discover_decorated_tools()
        # Re-add MCP tools
        self._sync_mcp_tools_into_catalog()
        # Manually added tools win on name clashes, as add_tool overwrites entries.
        self._tool_catalog.update(preserved)

    # ====================
    # LEGACY COMPATIBILITY
    # ====================

    def add_legacy_functions(self, functions: Dict[str, Callable]) -> int:
        """Add legacy functions for backward compatibility.

        Returns:
            Number of functions successfully added.
        """
        added_count = 0
        for name, func in functions.items():
            if self.add_tool(func, name, source=ToolSource.LOCAL):
                added_count += 1
        return added_count