|
import pdb
|
|
|
|
import pyperclip
|
|
from typing import Optional, Type, Callable, Dict, Any, Union, Awaitable, TypeVar
|
|
from pydantic import BaseModel
|
|
from browser_use.agent.views import ActionResult
|
|
from browser_use.browser.context import BrowserContext
|
|
from browser_use.controller.service import Controller, DoneAction
|
|
from browser_use.controller.registry.service import Registry, RegisteredAction
|
|
from main_content_extractor import MainContentExtractor
|
|
from browser_use.controller.views import (
|
|
ClickElementAction,
|
|
DoneAction,
|
|
ExtractPageContentAction,
|
|
GoToUrlAction,
|
|
InputTextAction,
|
|
OpenTabAction,
|
|
ScrollAction,
|
|
SearchGoogleAction,
|
|
SendKeysAction,
|
|
SwitchTabAction,
|
|
)
|
|
import logging
|
|
import inspect
|
|
import asyncio
|
|
import os
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from browser_use.agent.views import ActionModel, ActionResult
|
|
|
|
from src.utils.mcp_client import create_tool_param_model, setup_mcp_client_and_tools
|
|
|
|
from browser_use.utils import time_execution_sync
|
|
|
|
# Module-level logger named after this module; used by the controller's custom actions and MCP helpers.
logger = logging.getLogger(__name__)
|
|
|
|
# Type variable for the optional `context` object that `CustomController.act` forwards
# unchanged to the action registry.
Context = TypeVar('Context')
|
|
|
|
|
|
class CustomController(Controller):
    """Controller that adds human-assistance, file-upload and MCP tool actions.

    On construction the two custom actions (``ask_for_assistant`` and
    ``upload_file``) are registered on ``self.registry``. MCP tools are added
    later by :meth:`setup_mcp_client` / :meth:`register_mcp_tools` under names
    of the form ``mcp.<server_name>.<tool_name>``.
    """

    def __init__(
            self,
            exclude_actions: Optional[list[str]] = None,
            output_model: Optional[Type[BaseModel]] = None,
            ask_assistant_callback: Optional[Union[
                Callable[[str, BrowserContext], Dict[str, Any]],
                Callable[[str, BrowserContext], Awaitable[Dict[str, Any]]],
            ]] = None,
    ):
        """Create the controller and register the custom actions.

        Args:
            exclude_actions: Action names forwarded to the base controller's
                ``exclude_actions`` argument. ``None`` (the default) is
                treated as an empty list.
            output_model: Forwarded to the base controller unchanged.
            ask_assistant_callback: Sync or async callable invoked as
                ``callback(query, browser)`` by the ``ask_for_assistant``
                action. Its result is indexed with ``['response']``.
        """
        # Build a fresh list per instance instead of sharing one mutable
        # default object across every controller.
        super().__init__(
            exclude_actions=[] if exclude_actions is None else exclude_actions,
            output_model=output_model,
        )
        self._register_custom_actions()
        # The registered action reads this attribute lazily at call time, so
        # assigning it after registration is safe.
        self.ask_assistant_callback = ask_assistant_callback
        self.mcp_client = None
        self.mcp_server_config = None

    def _register_custom_actions(self):
        """Register all custom browser actions"""

        @self.registry.action(
            "When executing tasks, prioritize autonomous completion. However, if you encounter a definitive blocker "
            "that prevents you from proceeding independently – such as needing credentials you don't possess, "
            "requiring subjective human judgment, needing a physical action performed, encountering complex CAPTCHAs, "
            "or facing limitations in your capabilities – you must request human assistance."
        )
        async def ask_for_assistant(query: str, browser: BrowserContext):
            # Without a callback there is nobody to ask; tell the agent so.
            if not self.ask_assistant_callback:
                return ActionResult(extracted_content="Human cannot help you. Please try another way.",
                                    include_in_memory=True)

            # Call first, then await only if the result is awaitable. This also
            # covers callable objects with an async ``__call__`` and sync
            # wrappers returning a coroutine, which
            # ``inspect.iscoroutinefunction`` does not recognise.
            user_response = self.ask_assistant_callback(query, browser)
            if inspect.isawaitable(user_response):
                user_response = await user_response

            msg = f"AI ask: {query}. User response: {user_response['response']}"
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

        @self.registry.action(
            'Upload file to interactive element with file path ',
        )
        async def upload_file(index: int, path: str, browser: BrowserContext, available_file_paths: list[str]):
            # Only paths that were explicitly made available may be uploaded.
            # A missing/empty allow-list is reported as an error result rather
            # than raising TypeError on the membership test.
            if not available_file_paths or path not in available_file_paths:
                return ActionResult(error=f'File path {path} is not available')

            if not os.path.exists(path):
                return ActionResult(error=f'File {path} does not exist')

            dom_el = await browser.get_dom_element_by_index(index)
            file_upload_dom_el = dom_el.get_file_upload_element()

            if file_upload_dom_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            file_upload_el = await browser.get_locate_element(file_upload_dom_el)

            if file_upload_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            try:
                await file_upload_el.set_input_files(path)
            except Exception as e:
                # Failures are reported to the agent as an error result
                # instead of propagating.
                msg = f'Failed to upload file to index {index}: {str(e)}'
                logger.info(msg)
                return ActionResult(error=msg)

            msg = f'Successfully uploaded file to index {index}'
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

    # NOTE(review): `time_execution_sync` is applied to an `async def`; it
    # presumably times only the creation of the coroutine, not its execution.
    # Confirm whether an async-aware timer is intended here.
    @time_execution_sync('--act')
    async def act(
            self,
            action: ActionModel,
            browser_context: Optional[BrowserContext] = None,
            page_extraction_llm: Optional[BaseChatModel] = None,
            sensitive_data: Optional[Dict[str, str]] = None,
            available_file_paths: Optional[list[str]] = None,
            context: Context | None = None,
    ) -> ActionResult:
        """Execute an action

        The first entry of ``action.model_dump(exclude_unset=True)`` whose
        parameters are not ``None`` is executed and its result returned.
        Action names starting with ``"mcp"`` are looked up in the registry and
        invoked through the tool's ``ainvoke``; everything else goes through
        ``self.registry.execute_action``.

        Returns:
            The action's ``ActionResult``; a ``str`` result is wrapped as
            ``extracted_content``; ``None`` (or no executable action) yields
            an empty ``ActionResult``.

        Raises:
            ValueError: If an MCP action name is not registered, or if the
                action returns a value of an unsupported type.
        """
        for action_name, params in action.model_dump(exclude_unset=True).items():
            if params is None:
                continue

            if action_name.startswith("mcp"):
                logger.debug(f"Invoke MCP tool: {action_name}")
                registered = self.registry.registry.actions.get(action_name)
                if registered is None:
                    # Fail with a clear message instead of an AttributeError
                    # on `None.function`.
                    raise ValueError(f'MCP tool {action_name} is not registered')
                result = await registered.function.ainvoke(params)
            else:
                result = await self.registry.execute_action(
                    action_name,
                    params,
                    browser=browser_context,
                    page_extraction_llm=page_extraction_llm,
                    sensitive_data=sensitive_data,
                    available_file_paths=available_file_paths,
                    context=context,
                )

            if isinstance(result, str):
                return ActionResult(extracted_content=result)
            if isinstance(result, ActionResult):
                return result
            if result is None:
                return ActionResult()
            raise ValueError(f'Invalid action result type: {type(result)} of {result}')

        return ActionResult()

    async def setup_mcp_client(self, mcp_server_config: Optional[Dict[str, Any]] = None):
        """Store the MCP server config and, if it is non-empty, start the client.

        When a config is given, the client returned by
        ``setup_mcp_client_and_tools`` is stored on ``self.mcp_client`` and its
        tools are registered via :meth:`register_mcp_tools`.
        """
        self.mcp_server_config = mcp_server_config
        if self.mcp_server_config:
            self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
            self.register_mcp_tools()

    def register_mcp_tools(self):
        """
        Register the MCP tools used by this controller.

        Each tool is stored in the registry's action table under
        ``mcp.<server_name>.<tool.name>``. Logs a warning and does nothing
        when no MCP client is set.
        """
        if not self.mcp_client:
            logger.warning("MCP client not started.")
            return

        for server_name, tools in self.mcp_client.server_name_to_tools.items():
            for tool in tools:
                tool_name = f"mcp.{server_name}.{tool.name}"
                self.registry.registry.actions[tool_name] = RegisteredAction(
                    name=tool_name,
                    description=tool.description,
                    function=tool,
                    param_model=create_tool_param_model(tool),
                )
                logger.info(f"Add mcp tool: {tool_name}")
            logger.debug(
                f"Registered {len(tools)} mcp tools for {server_name}")

    async def close_mcp_client(self):
        """Exit the MCP client's async context, if one is active.

        The client reference is cleared afterwards (even if exiting raises) so
        that a repeated call is a no-op and ``register_mcp_tools`` no longer
        sees a closed client as started.
        """
        if self.mcp_client:
            try:
                await self.mcp_client.__aexit__(None, None, None)
            finally:
                self.mcp_client = None
|
|
|