Spaces:
Running
Running
| """ | |
| CodeAgent: A LangGraph-based agent for executing Python code and using tools. | |
| Fully modular version with unified tool management. | |
| """ | |
import os
import re
import time
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
from jinja2 import Template
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

# Import core types and constants
from core.constants import SYSTEM_PROMPT_TEMPLATE
from core.types import AgentState, AgentConfig

# Import managers (organized by subsystem)
from managers import (
    # Support
    PackageManager,
    ConsoleDisplay,
    # Workflow
    PlanManager,
    StateManager,
    WorkflowEngine,
    # Tools
    ToolManager,
    ToolSource,
    ToolSelector,
    # Execution
    Timing,
    PythonExecutor
)
# Load environment variables at import time. The path is relative ("./.env"),
# so which file is read depends on the process's current working directory.
load_dotenv("./.env")
def get_system_prompt(functions: Dict[str, dict], packages: Dict[str, str] = None) -> str:
    """Render SYSTEM_PROMPT_TEMPLATE for the given tools and packages.

    Args:
        functions: Mapping of tool name to its schema, passed to the template
            as ``functions``.
        packages: Mapping passed to the template as ``packages``. When omitted
            (``None``), ``core.constants.LIBRARY_CONTENT_DICT`` is used.

    Returns:
        The rendered system prompt text.
    """
    if packages is not None:
        package_listing = packages
    else:
        # Fall back to the project's default package table.
        from core.constants import LIBRARY_CONTENT_DICT
        package_listing = LIBRARY_CONTENT_DICT

    prompt_template = Template(SYSTEM_PROMPT_TEMPLATE)
    return prompt_template.render(functions=functions, packages=package_listing)
class CodeAgent:
    """A code-based agent that can execute Python code and use tools to solve tasks.

    The agent registers three callbacks (``generate``, ``execute`` and
    ``should_continue``) with a ``WorkflowEngine`` and delegates tool, package,
    plan and state handling to the corresponding manager objects.
    """

    def __init__(self, model: BaseChatModel,
                 config: Optional[AgentConfig] = None,
                 use_tool_manager: bool = True,
                 use_tool_selection: bool = True):
        """
        Initialize the CodeAgent with unified tool management.

        Args:
            model: The language model to use for generation
            config: Configuration for the agent (defaults to ``AgentConfig()``)
            use_tool_manager: Must be True; the legacy mode has been removed
            use_tool_selection: Whether to use LLM-based tool selection (like Biomni)

        Raises:
            ValueError: If ``use_tool_manager`` is False.
        """
        # Fail fast, before any manager object is constructed.
        if not use_tool_manager:
            raise ValueError("ToolManager is required. Legacy mode (use_tool_manager=False) has been removed.")

        self.model = model
        self.config = config or AgentConfig()
        self.use_tool_manager = use_tool_manager
        self.use_tool_selection = use_tool_selection

        # Cache selected tools to avoid re-selection at each step
        self._selected_tools_cache = None

        # Initialize modular components
        self.package_manager = PackageManager()
        self.console = ConsoleDisplay()
        self.state_manager = StateManager()
        self.plan_manager = PlanManager()

        # Initialize unified tool management
        self.tool_manager = ToolManager(self.console)

        # Initialize tool selector for LLM-based tool selection
        self.tool_selector = ToolSelector(self.model) if self.use_tool_selection else None

        # Initialize workflow engine
        self.workflow_engine = WorkflowEngine(model, self.config, self.console, self.state_manager)

        # Initialize Python executor
        self.python_executor = PythonExecutor()

        # Setup workflow
        self._setup_workflow()

    # ====================
    # WORKFLOW SETUP
    # ====================

    def _setup_workflow(self) -> None:
        """Register the generate/execute/should_continue callbacks with the WorkflowEngine."""
        self.workflow_engine.setup_workflow(
            self.generate,
            self.execute,
            self.should_continue
        )

    # ====================
    # WORKFLOW NODES
    # ====================

    def _resolve_prompt_tools(self, state: AgentState,
                              all_functions_dict: Dict[str, dict]) -> Dict[str, dict]:
        """Return the tool schemas to advertise in the system prompt.

        With tool selection enabled, the selector is asked once (using the first
        non-empty message as the query) and the result is cached on the instance;
        later calls return the cache. Otherwise every known tool is returned.
        """
        if not self.use_tool_selection:
            return all_functions_dict
        if self._selected_tools_cache is not None:
            # Use cached selected tools
            return self._selected_tools_cache
        if not self.tool_selector or not state.get("messages"):
            return all_functions_dict

        # The first message carrying content is treated as the user's query.
        user_query = next(
            (msg.content for msg in state["messages"] if getattr(msg, 'content', None)),
            "",
        )
        if not user_query:
            return all_functions_dict

        # Prepare tools for selection (convert schemas to tool info format)
        available_tools = {
            tool_name: {
                'description': schema['function'].get('description', 'No description'),
                'source': 'tool_manager'  # Could be enhanced to show actual source
            }
            for tool_name, schema in all_functions_dict.items()
        }

        # Select relevant tools using LLM (only once)
        selected_tool_names = self.tool_selector.select_tools_for_task(
            user_query, available_tools, max_tools=15
        )

        # Cache the selection, silently dropping names the manager does not know.
        self._selected_tools_cache = {name: all_functions_dict[name]
                                      for name in selected_tool_names
                                      if name in all_functions_dict}
        self.console.console.print(f"π― Selected {len(self._selected_tools_cache)} tools from {len(all_functions_dict)} available tools (cached for session)")
        return self._selected_tools_cache

    def generate(self, state: AgentState) -> AgentState:
        """Generate the next LLM response using a tool-aware system prompt.

        Args:
            state: Current workflow state; ``state["messages"]`` is the history
                sent to the model.

        Returns:
            A state dict (built by the StateManager) holding the new AI message,
            the incremented step count and the plan extracted from the reply.
        """
        # Get all available tools first
        all_schemas = self.tool_manager.get_tool_schemas(openai_format=True)
        all_functions_dict = {schema['function']['name']: schema for schema in all_schemas}
        functions_dict = self._resolve_prompt_tools(state, all_functions_dict)

        all_packages = self.package_manager.get_all_packages()
        system_prompt = get_system_prompt(functions_dict, all_packages)

        # NOTE: the full message history is sent as-is; no truncation happens here.
        messages = [SystemMessage(content=system_prompt)] + state["messages"]
        response = self.model.invoke(messages)

        # Normalise to text first so the tag check below is a substring test even
        # when the model returns non-string content.
        msg = str(response.content)

        # Cut the text after the first </execute> tag, while keeping the tag
        if "</execute>" in msg:
            msg = msg.split("</execute>", 1)[0] + "</execute>"

        llm_reply = AIMessage(content=msg.strip())

        return self.state_manager.create_state_dict(
            messages=[llm_reply],
            step_count=state.get("step_count", 0) + 1,
            error_count=state.get("error_count", 0),
            start_time=state.get("start_time", time.time()),
            current_plan=self._extract_current_plan(msg)
        )

    def _extract_current_plan(self, content: str) -> Optional[str]:
        """Extract the current plan from the agent's response (delegated to PlanManager)."""
        return self.plan_manager.extract_plan_from_content(content)

    def _reply_state(self, state: AgentState, content: str, error_increment: int = 0) -> AgentState:
        """Build a state dict that carries `state` forward with one new AI message."""
        return self.state_manager.create_state_dict(
            messages=[AIMessage(content=content)],
            step_count=state.get("step_count", 0),
            error_count=state.get("error_count", 0) + error_increment,
            start_time=state.get("start_time", time.time()),
            current_plan=state.get("current_plan")
        )

    def execute(self, state: AgentState) -> AgentState:
        """Run the code inside the last message's <execute> block.

        Args:
            state: Current workflow state; the last message is searched for the
                first ``<execute>...</execute>`` block.

        Returns:
            A state dict whose new message is either an ``<observation>`` with the
            executor output, or an ``<error>`` (which also increments the error
            count) when no code is found or execution raises.
        """
        try:
            last_message = state["messages"][-1].content
            execute_match = re.search(r"<execute>(.*?)</execute>", last_message, re.DOTALL)
            if execute_match is None:
                content, error_increment = "<error>No executable code found</error>", 1
            else:
                code = execute_match.group(1).strip()
                # Execute regular code in persistent environment (tools already injected)
                result = self.python_executor(code)
                # Only the execution output is reported back, not the code itself.
                content = f"\n<observation>\nCode Output:\n{result}</observation>".strip()
                error_increment = 0
        except Exception as e:
            # Deliberately broad: any failure is surfaced to the agent as a message
            # instead of aborting the workflow.
            content, error_increment = f"<error>Execution error: {str(e)}</error>", 1

        return self._reply_state(state, content, error_increment)

    def should_continue(self, state: AgentState) -> str:
        """Decide whether to continue executing or end the workflow.

        Returns:
            ``"execute"`` when the last message contains an execute block and no
            limit (timeout, max steps, error budget) is hit and no solution was
            given; ``"end"`` otherwise.
        """
        last_message = state["messages"][-1].content
        elapsed = time.time() - state.get("start_time", time.time())

        limit_reached = (
            elapsed > self.config.timeout_seconds
            or state.get("step_count", 0) >= self.config.max_steps
            or state.get("error_count", 0) >= self.config.retry_attempts
        )
        if limit_reached:
            return "end"

        # A complete <solution> block takes precedence over any execute block.
        if "<solution>" in last_message and "</solution>" in last_message:
            return "end"
        if "<execute>" in last_message and "</execute>" in last_message:
            return "execute"
        return "end"

    # ====================
    # PACKAGE MANAGEMENT - Delegated to PackageManager
    # ====================

    def add_packages(self, packages: Dict[str, str]) -> bool:
        """Add new packages to the available packages."""
        return self.package_manager.add_packages(packages)

    def get_all_packages(self) -> Dict[str, str]:
        """Get all available packages, as reported by the PackageManager."""
        return self.package_manager.get_all_packages()

    # ====================
    # UNIFIED TOOL MANAGEMENT - Delegated to ToolManager
    # ====================

    def add_tool(self, function: Callable, name: Optional[str] = None,
                 description: Optional[str] = None) -> bool:
        """Register `function` with the ToolManager as a ``ToolSource.LOCAL`` tool."""
        return self.tool_manager.add_tool(function, name, description, ToolSource.LOCAL)

    def remove_tool(self, name: str) -> bool:
        """Remove a tool by name."""
        return self.tool_manager.remove_tool(name)

    def list_tools(self, source: str = "all", include_details: bool = False) -> List[Dict]:
        """List tools, optionally filtered by source.

        Args:
            source: ``"local"``, ``"decorated"`` or ``"mcp"`` (case-insensitive);
                any other value lists tools from all sources.
            include_details: Forwarded to the ToolManager.
        """
        normalized = source.lower()
        if normalized in ("local", "decorated", "mcp"):
            source_enum = ToolSource(normalized)
        else:
            source_enum = ToolSource.ALL
        return self.tool_manager.list_tools(source_enum, include_details)

    def search_tools(self, query: str) -> List[Dict]:
        """Search tools via the ToolManager."""
        return self.tool_manager.search_tools(query)

    def get_tool_info(self, name: str) -> Optional[Dict]:
        """Return a plain-dict description of the named tool, or None if unknown."""
        tool_info = self.tool_manager.get_tool(name)
        if not tool_info:
            return None
        return {
            "name": tool_info.name,
            "description": tool_info.description,
            "source": tool_info.source.value,
            "server": tool_info.server,
            "module": tool_info.module,
            "has_function": tool_info.function is not None,
            "required_parameters": tool_info.required_parameters,
            "optional_parameters": tool_info.optional_parameters
        }

    def get_all_tool_functions(self) -> Dict[str, Callable]:
        """Get all tool functions as a dictionary."""
        return self.tool_manager.get_all_functions()

    # ====================
    # MCP METHODS - Now delegated to ToolManager
    # ====================

    def add_mcp(self, config_path: str = "./mcp_config.yaml") -> None:
        """Add MCP tools from configuration file."""
        self.tool_manager.add_mcp_server(config_path)

    def list_mcp_tools(self) -> List[Dict]:
        """List all loaded MCP tools."""
        # ToolSource is the module-level enum (as used in list_tools/add_tool),
        # not an attribute of the ToolManager instance.
        return self.tool_manager.list_tools(ToolSource.MCP)

    def list_mcp_servers(self) -> Dict[str, List[str]]:
        """List all MCP servers and their tools."""
        return self.tool_manager.list_mcp_servers()

    def show_mcp_status(self) -> None:
        """Display detailed MCP status information to the user."""
        self.tool_manager.show_mcp_status()

    def get_mcp_summary(self) -> Dict[str, Any]:
        """Get a summary of MCP tools for programmatic access."""
        return self.tool_manager.get_mcp_summary()

    # ====================
    # ENHANCED TOOL FEATURES
    # ====================

    def get_tool_statistics(self) -> Dict[str, Any]:
        """Get tool statistics from the ToolManager."""
        return self.tool_manager.get_tool_statistics()

    def validate_tools(self) -> Dict[str, List[str]]:
        """Validate all tools and return any issues."""
        return self.tool_manager.validate_tools()

    # ====================
    # TOOL SELECTION MANAGEMENT
    # ====================

    def reset_tool_selection(self) -> None:
        """Reset the cached tool selection to allow re-selection on next query."""
        self._selected_tools_cache = None
        if self.use_tool_selection:
            self.console.console.print("π Tool selection cache cleared - will re-select tools on next query")

    def get_selected_tools(self) -> Optional[List[str]]:
        """Return the names of the cached selected tools, or None when the cache is unset or empty."""
        if not self._selected_tools_cache:
            return None
        return list(self._selected_tools_cache.keys())

    # ====================
    # TRACE AND SUMMARY METHODS
    # ====================

    def get_trace(self) -> Dict:
        """Get the trace recorded by the workflow engine.

        Note that ``execution_time`` is the time of this call, formatted as
        ``YYYY-MM-DD HH:MM:SS``.
        """
        if not self.workflow_engine:
            return {}
        return {
            "execution_time": time.strftime('%Y-%m-%d %H:%M:%S'),
            "config": {
                "max_steps": self.config.max_steps,
                "timeout_seconds": self.config.timeout_seconds,
                "verbose": self.config.verbose
            },
            "messages": self.workflow_engine.message_history,
            "trace_logs": self.workflow_engine.trace_logs
        }

    def get_summary(self) -> Dict:
        """Get a summary of the last execution."""
        if not self.workflow_engine:
            return {}
        return self.workflow_engine.generate_summary()

    def save_trace(self, filepath: Optional[str] = None) -> str:
        """Save the trace of the last execution to a file.

        Raises:
            RuntimeError: If no workflow engine is available.
        """
        if not self.workflow_engine:
            raise RuntimeError("No workflow engine available")
        return self.workflow_engine.save_trace_to_file(filepath)

    def save_summary(self, filepath: Optional[str] = None) -> str:
        """Save the summary of the last execution to a file.

        Raises:
            RuntimeError: If no workflow engine is available.
        """
        if not self.workflow_engine:
            raise RuntimeError("No workflow engine available")
        return self.workflow_engine.save_summary_to_file(filepath)

    # ====================
    # PUBLIC INTERFACE
    # ====================

    def run(self, query: str, save_trace: bool = False, save_summary: bool = False,
            trace_dir: str = "traces") -> str:
        """
        Run the agent with a given query using modular components.

        The tool-selection cache is not cleared here; call
        ``reset_tool_selection()`` first to force a fresh selection.

        Args:
            query: The task/question to solve
            save_trace: Whether to save the complete trace to a file
            save_summary: Whether to save the execution summary to a file
            trace_dir: Directory to save trace and summary files

        Returns:
            The result returned by the workflow engine
        """
        # Start timing the overall execution
        overall_timing = Timing(start_time=time.time())

        # Display task header
        self.console.print_task_header(query)

        # Collect the callable tools that will be injected into the executor
        functions_dict = self.get_all_tool_functions()

        # Display tool information broken down by source
        stats = self.tool_manager.get_tool_statistics()
        mcp_servers = self.tool_manager.list_mcp_servers()
        by_source = stats['by_source']
        self.console.console.print(f"π οΈ Loaded {stats['total_tools']} total tools:")
        if by_source['local'] > 0:
            self.console.console.print(f" π Local tools: {by_source['local']}")
        if by_source['decorated'] > 0:
            self.console.console.print(f" π― Decorated tools: {by_source['decorated']}")
        if by_source['mcp'] > 0:
            self.console.console.print(f" π MCP tools: {by_source['mcp']} from {len(mcp_servers)} servers")
            for server_name, tools in mcp_servers.items():
                self.console.console.print(f" β’ {server_name}: {len(tools)} tools")

        # Inject functions into Python executor
        self.python_executor.send_functions(functions_dict)

        # Import available packages using PackageManager
        imported_packages, failed_packages = self.package_manager.import_packages(self.python_executor)
        self.console.print_packages_info(imported_packages, failed_packages)

        # Inject any initial variables (none at present)
        self.python_executor.send_variables({})

        # Create initial state using StateManager
        input_state = self.state_manager.create_state_dict(
            messages=[HumanMessage(content=query)],
            step_count=0,
            error_count=0,
            start_time=time.time(),
            current_plan=None
        )

        # Execute workflow using WorkflowEngine and get result with final state
        result, final_state = self.workflow_engine.run_workflow(input_state)

        # Complete overall timing and display summary
        overall_timing.end_time = time.time()
        final_step_count = final_state.get("step_count", 0) if final_state else 0
        final_error_count = final_state.get("error_count", 0) if final_state else 0
        self.console.print_execution_summary(final_step_count, final_error_count, overall_timing.duration)

        # Save trace and summary if requested
        if save_trace or save_summary:
            from pathlib import Path

            # Create trace directory if it doesn't exist
            trace_path = Path(trace_dir)
            trace_path.mkdir(parents=True, exist_ok=True)

            # One timestamp for both files so a run's trace and summary pair up.
            timestamp = time.strftime('%Y%m%d_%H%M%S')
            if save_trace:
                trace_file = trace_path / f"agent_trace_{timestamp}.json"
                saved_trace = self.workflow_engine.save_trace_to_file(str(trace_file))
                self.console.console.print(f"πΎ Trace saved to: {saved_trace}")
            if save_summary:
                summary_file = trace_path / f"agent_summary_{timestamp}.json"
                saved_summary = self.workflow_engine.save_summary_to_file(str(summary_file))
                self.console.console.print(f"π Summary saved to: {saved_summary}")

        return result
| # ==================== | |
| # EXAMPLE USAGE | |
| # ==================== | |
if __name__ == "__main__":
    # Example usage of the fully modular CodeAgent architecture

    # Secrets are read from the environment, never embedded in source. Any key
    # that was previously committed in this file should be treated as leaked
    # and revoked.
    alphagenome_api_key = os.environ["ALPHAGENOME_API_KEY"]

    # Create LLM client (OpenRouter endpoint). To use a different chat model,
    # import its class at the top of the file and construct it here instead.
    model = ChatOpenAI(
        model="google/gemini-2.5-flash",
        base_url="https://openrouter.ai/api/v1",
        temperature=0.7,
        api_key=os.environ["OPENROUTER_API_KEY"],
    )

    # Create configuration
    config = AgentConfig(
        max_steps=15,
        max_conversation_length=30,
        retry_attempts=3,
        timeout_seconds=1200,
        verbose=True
    )

    # Create agent with unified tool management and LLM-based tool selection
    agent = CodeAgent(model=model, config=config, use_tool_manager=True, use_tool_selection=True)

    # Demonstrate tool management capabilities
    print("\nπ§ Tool Management Demo:")

    # Show tool statistics
    stats = agent.get_tool_statistics()
    print(f"π Tool Statistics: {stats}")

    # Add MCP tools (best effort: the demo continues without them on failure)
    try:
        print("π§ Loading MCP tools...")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(script_dir, "mcp_config.yaml")
        agent.add_mcp(config_path)
        print("β MCP tools loaded successfully!")

        # Show enhanced MCP status
        agent.show_mcp_status()

        # Show all available tools by source
        all_tools = agent.list_tools("all", include_details=True)
        print("\nπ All tools loaded:")
        for tool in all_tools:
            print(f" β’ {tool['name']} ({tool['source']}) - {tool['description'][:50]}...")
    except Exception as e:
        print(f"β οΈ Could not load MCP tools: {e}")

    # Validate tools
    issues = agent.validate_tools()
    if any(issues.values()):
        print(f"β οΈ Tool validation issues: {issues}")
    else:
        print("β All tools validated successfully!")

    # List available packages
    print(f"\nπ¦ Available packages: {list(agent.get_all_packages().keys())}")

    # Run the agent with a query and save trace/summary
    print("\nπ Running agent with trace and summary saving...")
    query = (
        "\n"
        "Use AlphaGenome MCP to analyze heart gene expression data to identify the causal gene\n"
        "for the variant chr11:116837649:T>G, associated with Hypoalphalipoproteinemia. "
        f"My API key is: {alphagenome_api_key}.\n"
    )
    result = agent.run(
        query=query,
        save_trace=True,      # Save complete execution trace
        save_summary=True,    # Save execution summary
        trace_dir="traces"    # Directory to save files
    )

    # You can also access trace and summary programmatically
    print("\nπ Execution Summary:")
    summary = agent.get_summary()
    print(f" Total steps: {summary.get('total_steps', 0)}")
    print(f" Code executions: {len(summary.get('code_executions', []))}")
    print(f" Observations: {len(summary.get('observations', []))}")
    print(f" Errors: {len(summary.get('errors', []))}")

    # You can save trace/summary manually after execution
    # agent.save_trace("custom_trace.json")
    # agent.save_summary("custom_summary.json")