import inspect
from collections import OrderedDict
from typing import Callable, Dict, List, Union

from lagent.actions.base_action import BaseAction
from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
from lagent.hooks import Hook, RemovableHandle
from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
from lagent.utils import create_object


class ActionExecutor:
    """The action executor class.

    Holds a registry of actions keyed by their ``name`` attribute and
    dispatches function-call messages to them, running registered hooks
    before and after every call.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): The
            action or actions. Every entry is passed through
            ``create_object`` before being registered.
        invalid_action (BaseAction, optional): The invalid action. Defaults to
            InvalidAction().
        no_action (BaseAction, optional): The no action.
            Defaults to NoAction().
        finish_action (BaseAction, optional): The finish action. Defaults to
            FinishAction().
        finish_in_action (bool, optional): Whether the finish action is in the
            action list. Defaults to False.
        hooks (List[Dict], optional): Hooks to register at construction time.
            Every entry is passed through ``create_object`` first.
            Defaults to None.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: BaseAction = dict(type=InvalidAction),
        no_action: BaseAction = dict(type=NoAction),
        finish_action: BaseAction = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: Union[List[Dict], None] = None,
    ):
        # Work on a private copy: the caller's list must not be extended with
        # the finish action nor have its entries replaced by created objects.
        actions = list(actions) if isinstance(actions, list) else [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        actions = [create_object(action) for action in actions]
        # Later entries win when two actions share the same name.
        self.actions = {action.name: action for action in actions}

        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        # Maps handle ids to hooks, kept in registration order.
        self._hooks: Dict[int, Hook] = OrderedDict()
        for hook in hooks or ():
            self.register_hook(create_object(hook))

    def description(self) -> List[Dict]:
        """Collect the descriptions of all registered actions.

        For toolkit actions every entry of ``description['api_list']`` is
        listed separately with its name rewritten to ``<action>.<api>``;
        other actions contribute a shallow copy of their description.

        Returns:
            List[Dict]: One description dict per callable entry.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                actions.append(action.description.copy())
        return actions

    def __contains__(self, name: str):
        """Return whether an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the names of the registered actions as a list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (passed through ``create_object`` first).

        Note that the action is stored under its own ``name`` attribute;
        the ``name`` argument is not used as the key.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Remove the action registered under ``name``.

        Raises:
            KeyError: If no such action is registered.
        """
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Run the action addressed by ``name`` with ``parameters``.

        Args:
            name (str): Either ``<action>`` (the ``run`` API is used) or
                ``<action>.<api>``.
            parameters: Arguments handed to the action unchanged.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ActionReturn: Result of the registered action, or of the no /
            finish / invalid action when ``name`` is not registered.
        """
        if '.' in name:
            # Split once, on the last dot, so that names with several dots no
            # longer fail to unpack; such names fall through to the regular
            # lookup below (and to ``invalid_action`` when unknown).
            action_name, api_name = name.rsplit('.', 1)
        else:
            action_name, api_name = name, 'run'

        if action_name in self:
            action_return = self.actions[action_name](parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
            return action_return

        # Not a registered action: match the full name against the built-in
        # fallbacks, defaulting to the invalid action.
        if name == self.no_action.name:
            return self.no_action(parameters)
        if name == self.finish_action.name:
            return self.finish_action(parameters)
        return self.invalid_action(parameters)

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Execute the function call carried by ``message``.

        Runs every hook's ``before_action`` (a truthy return value replaces
        the message), dispatches through :meth:`forward`, wraps the result
        in an ``AgentMessage`` if it is not one already, then runs every
        hook's ``after_action`` in the same way.

        Args:
            message (AgentMessage): Its ``content`` must be a ``FunctionCall``
                or a dict with ``'name'`` and ``'parameters'`` keys.
            session_id: Passed through to the hooks. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The (possibly hook-replaced) response message.

        Raises:
            AssertionError: If ``message.content`` has neither accepted form.
        """
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'message.content must be a FunctionCall or a dict with '
                "'name' and 'parameters' keys")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return the handle created for it.

        The hook is stored under the id of a new ``RemovableHandle`` built
        on the executor's hook registry.
        """
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle


class AsyncActionExecutor(ActionExecutor):
    """Asynchronous counterpart of :class:`ActionExecutor`.

    ``forward`` and ``__call__`` are coroutines. Actions and hooks may be
    implemented either synchronously or asynchronously: whatever they
    return is awaited when it is awaitable and used as-is otherwise.
    """

    @staticmethod
    async def _resolve(result):
        """Await ``result`` when it is awaitable; otherwise return it as-is."""
        if inspect.isawaitable(result):
            return await result
        return result

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Run the action addressed by ``name`` with ``parameters``.

        Args:
            name (str): Either ``<action>`` (the ``run`` API is used) or
                ``<action>.<api>``.
            parameters: Arguments handed to the action unchanged.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ActionReturn: Result of the registered action, or of the no /
            finish / invalid action when ``name`` is not registered.
        """
        if '.' in name:
            # Split once, on the last dot, so that names with several dots no
            # longer fail to unpack; such names fall through to the regular
            # lookup below (and to ``invalid_action`` when unknown).
            action_name, api_name = name.rsplit('.', 1)
        else:
            action_name, api_name = name, 'run'

        if action_name in self:
            action = self.actions[action_name]
            action_return = await self._resolve(action(parameters, api_name))
            action_return.valid = ActionValidCode.OPEN
            return action_return

        # Not a registered action: match the full name against the built-in
        # fallbacks, defaulting to the invalid action. Their results are
        # resolved too, so an async fallback does not leak a bare coroutine.
        if name == self.no_action.name:
            fallback = self.no_action
        elif name == self.finish_action.name:
            fallback = self.finish_action
        else:
            fallback = self.invalid_action
        return await self._resolve(fallback(parameters))

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Execute the function call carried by ``message``.

        Runs every hook's ``before_action`` (a truthy return value replaces
        the message), dispatches through :meth:`forward`, wraps the result
        in an ``AgentMessage`` if it is not one already, then runs every
        hook's ``after_action`` in the same way.

        Args:
            message (AgentMessage): Its ``content`` must be a ``FunctionCall``
                or a dict with ``'name'`` and ``'parameters'`` keys.
            session_id: Passed through to the hooks. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The (possibly hook-replaced) response message.

        Raises:
            AssertionError: If ``message.content`` has neither accepted form.
        """
        for hook in self._hooks.values():
            result = await self._resolve(
                hook.before_action(self, message, session_id))
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'message.content must be a FunctionCall or a dict with '
                "'name' and 'parameters' keys")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = await self._resolve(
                hook.after_action(self, response_message, session_id))
            if result:
                response_message = result
        return response_message