import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Callable, Optional, Type, get_args, get_origin

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

from griffe import Docstring

try:
    from griffe import DocstringSectionKind
except ImportError:
    from griffe.enumerations import DocstringSectionKind

from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError

# Raise griffe's logger threshold to ERROR so its lower-severity messages are
# dropped; presumably these are warnings emitted while parsing the loosely
# formatted tool docstrings handled in this module.
logging.getLogger('griffe').setLevel(logging.ERROR)


def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    Examples:

        .. code-block:: python

            # type hints have higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    x (int): a
                    y (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.
        **kwargs: extra options forwarded to griffe's Google-style parser when
            parsing the decorated function's docstring.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """
    # Matches reST roles such as :class:`Foo`; only the role content is kept.
    rst_role = re.compile(r':(.+?):`(.+?)`')

    def _detect_type(string):
        """Map a type string to a field type by substring heuristics."""
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _to_dict(element):
        """Convert a griffe docstring element into a plain dict.

        The raw ``annotation`` entry is always removed; when it was non-empty
        it is replaced by a detected ``type`` entry.
        """
        item = element.as_dict()
        annotation = item.pop('annotation', None)
        if annotation:
            item['type'] = _detect_type(str(annotation).lower())
        return item

    def _explode(desc):
        """Parse the bullet lines of a return description as named fields."""
        # Re-indent every line after the first one (the summary) under a fake
        # ``Args:`` header so griffe parses them as parameters.
        lines = [
            '    ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ]
        docs = Docstring('\nArgs:\n' + '\n'.join(lines)).parse('google')
        if not docs or docs[0].kind is not DocstringSectionKind.parameters:
            return []
        return [_to_dict(d) for d in docs[0].value]

    def _annotation_type(annotation):
        """Detect the field type of a signature type hint."""
        # Check the container before unwrapping generics; otherwise
        # ``List[int]`` would be judged only by its element type.
        if get_origin(annotation) is list:
            return 'Array'
        while get_origin(annotation):
            annotation = get_args(annotation)
        return _detect_type(str(annotation))

    def _parse_tool(function):
        """Build the API description dict of ``function``."""
        docs = Docstring(rst_role.sub(r'\2', function.__doc__ or '')).parse(
            'google', returns_named_value=returns_named_value, **kwargs)
        has_summary = bool(docs) and docs[0].kind is DocstringSectionKind.text
        desc = dict(
            name=function.__name__,
            description=docs[0].value if has_summary else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    item = _to_dict(d)
                    args_doc[item['name']] = item
            elif doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    item = _to_dict(d)
                    if not item.get('name'):
                        item.pop('name', None)
                    returns_doc.append(item)

        for name, param in inspect.signature(function).parameters.items():
            if name == 'self':
                continue
            arg_doc = args_doc.get(name, {})
            parameter = dict(
                name=name,
                type='STRING',
                description=arg_doc.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                parameter['type'] = arg_doc.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # Annotated[T, info, ...]: the first metadata item
                    # overrides the docstring description.
                    annotation, info, *_ = get_args(annotation)
                    if info:
                        parameter['description'] = info
                parameter['type'] = _annotation_type(annotation)
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(name)

        return_data = []
        if explode_return:
            # Without a Returns section there is nothing to explode.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def decorate(func):
        """Wrap ``func`` and attach its parsed ``api_description``."""
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                return await func(self, *args, **kwargs)

        else:

            @wraps(func)
            def wrapper(self, *args, **kwargs):
                return func(self, *args, **kwargs)

        wrapper.api_description = _parse_tool(func)
        return wrapper

    # Used bare (``@tool_api``) or with options (``@tool_api(...)``).
    if callable(func):
        return decorate(func)
    return decorate


class ToolMeta(ABCMeta):
    """Metaclass of tools.

    When a class is created, it builds ``__tool_description__`` from the
    class docstring and any ``tool_api``-decorated methods, and sets
    ``_is_toolkit``. A class defining decorated methods other than ``run`` is
    a toolkit whose APIs go into ``api_list``. Otherwise it is a single tool
    described by ``run``; an undecorated callable ``run`` is wrapped with
    ``tool_api``. Defining a decorated ``run`` together with other decorated
    APIs raises ``KeyError``.
    """

    def __new__(mcs, name, base, attrs):
        # ``__doc__`` may be absent or ``None``; the summary is only taken
        # from a leading text section.
        docs = Docstring(attrs.get('__doc__') or '').parse('google')
        has_summary = bool(docs) and docs[0].kind is DocstringSectionKind.text
        is_toolkit, tool_desc = True, dict(
            name=name, description=docs[0].value if has_summary else '')

        def _use_run(api_desc):
            """Copy the ``run`` API description into the tool description."""
            tool_desc['parameters'] = api_desc['parameters']
            tool_desc['required'] = api_desc['required']
            if api_desc['description']:
                tool_desc['description'] = api_desc['description']
            if api_desc.get('return_data'):
                tool_desc['return_data'] = api_desc['return_data']

        for key, value in attrs.items():
            if not (callable(value) and hasattr(value, 'api_description')):
                continue
            api_desc = value.api_description
            if key == 'run':
                _use_run(api_desc)
                is_toolkit = False
            else:
                tool_desc.setdefault('api_list', []).append(api_desc)
        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            # No decorated APIs at all: treat as a single tool built on an
            # (undecorated) ``run``, if there is one.
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                _use_run(run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []
        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)


class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # Keep a private deep copy so neither the caller's dict nor the
        # class-level ``__tool_description__`` is shared with this instance.
        source = description if description else self.__tool_description__
        self._description = deepcopy(source)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, invoke the API ``name`` and wrap the outcome.

        An unknown API name or an exception raised by the API call produces
        an ``API_ERROR`` return; a ``ParseError`` from input parsing produces
        an ``ARGS_ERROR`` return.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as err:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=err.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**parsed)
        except Exception as err:
            return ActionReturn(
                parsed,
                type=self.name,
                errmsg=str(err),
                state=ActionStatusCode.API_ERROR)
        if not isinstance(outputs, ActionReturn):
            result = self._parser.parse_outputs(outputs)
            return ActionReturn(parsed, type=self.name, result=result)
        # The API built its own ActionReturn: only fill in the empty fields.
        if not outputs.args:
            outputs.args = parsed
        if not outputs.type:
            outputs.type = self.name
        return outputs

    @property
    def name(self):
        """Tool name read from the description at construction time."""
        return self._name

    @property
    def is_toolkit(self):
        """Whether the class exposes several APIs instead of ``run``."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return str(self.description)

    __str__ = __repr__


class AsyncActionMixin:
    """Mixin providing an awaitable ``__call__`` for actions whose APIs are
    coroutine functions.

    It relies on ``self.name`` and ``self._parser`` being provided by the
    class it is mixed into.
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, await the API ``name`` and wrap the outcome.

        An unknown API name or an exception raised while awaiting the API
        produces an ``API_ERROR`` return; a ``ParseError`` from input parsing
        produces an ``ARGS_ERROR`` return.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as err:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=err.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = await getattr(self, name)(**parsed)
        except Exception as err:
            return ActionReturn(
                parsed,
                type=self.name,
                errmsg=str(err),
                state=ActionStatusCode.API_ERROR)
        if not isinstance(outputs, ActionReturn):
            result = self._parser.parse_outputs(outputs)
            return ActionReturn(parsed, type=self.name, result=result)
        # The API built its own ActionReturn: only fill in the empty fields.
        if not outputs.args:
            outputs.args = parsed
        if not outputs.type:
            outputs.type = self.name
        return outputs