import json
from enum import IntEnum

# import re
from typing import Any, Callable, List, Optional

from lagent.prompts.parsers import StrParser
from lagent.utils import create_object, load_class_from_string


def default_plugin_validate(plugin: str):
    """Decode the body of a plugin call as a JSON object.

    Args:
        plugin (str): Raw action text found between the plugin markers.
            Leading and trailing whitespace is ignored.

    Returns:
        dict: The decoded JSON object.

    Raises:
        json.JSONDecodeError: If the stripped text is not wrapped in
            ``{`` ... ``}`` or is not valid JSON.
    """
    plugin = plugin.strip()
    if not (plugin.startswith('{') and plugin.endswith('}')):
        # JSONDecodeError must be constructed with (msg, doc, pos); raising
        # the bare class would fail with a TypeError instead.
        raise json.JSONDecodeError(
            'Expecting a JSON object wrapped in braces', plugin, 0)
    return json.loads(plugin)


class ToolStatusCode(IntEnum):
    """Outcome of looking for a tool call in a model response."""
    # No ``begin`` marker was found; the response is treated as plain text.
    NO_TOOL = 0
    # A tool call was found and, if a ``validate`` callable is configured,
    # it accepted the action text.
    VALID_TOOL = 1
    # A tool call was found but the configured ``validate`` callable raised.
    PARSING_ERROR = -1


class ToolParser(StrParser):
    """Parse and render a single tool call wrapped in ``begin``/``end`` markers.

    A response is treated as ``<thought><begin><action><end>``:
    :meth:`parse_response` extracts the parts and :meth:`format_response`
    joins them back together.

    Args:
        tool_type (str): Identifier placed in the ``tool_type`` field of every
            parsed tool call.
        template (str): Instruction template forwarded to ``StrParser``.
        begin (str): Marker that opens a tool call.
        end (str): Marker that closes a tool call.
        validate (Optional[Callable[[str], Any]]): Applied to the raw action
            text when set. A string is also accepted and resolved with
            ``load_class_from_string`` (presumably a dotted import path).
            If the callable raises, the call is reported with
            ``ToolStatusCode.PARSING_ERROR``.
        **kwargs: Additional format fields forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: str,
                 template: str = '',
                 begin: str = '<tool>\n',
                 end: str = '</tool>\n',
                 validate: Optional[Callable[[str], Any]] = None,
                 **kwargs):
        super().__init__(template, begin=begin, end=end, **kwargs)
        self.template = template
        self.tool_type = tool_type
        # ``validate`` may arrive as a string (e.g. from a config); resolve
        # it to the actual callable once, at construction time.
        self.validate = load_class_from_string(validate) if isinstance(
            validate, str) else validate

    def parse_response(self, data: str) -> dict:
        """Split ``data`` into the free-text thought and the tool action.

        Args:
            data (str): Raw model output.

        Returns:
            dict: Keys ``tool_type``, ``thought``, ``action`` and ``status``.
            Without a ``begin`` marker, ``thought`` is the whole input,
            ``tool_type``/``action`` are ``None`` and ``status`` is
            ``NO_TOOL``. Otherwise ``thought`` is the text before the first
            ``begin``, and ``action`` is the text after it up to the first
            ``end`` (or the next ``begin``), passed through ``validate`` when
            one is set.
        """
        begin = self.format_field['begin']
        if begin not in data:
            return dict(
                tool_type=None,
                thought=data,
                action=None,
                status=ToolStatusCode.NO_TOOL)
        # Only the segment between the first and second ``begin`` is used;
        # any further tool calls in ``data`` are ignored.
        thought, action, *_ = data.split(begin)
        # A missing ``end`` marker is tolerated: the action then runs to the
        # end of that segment.
        action = action.split(self.format_field['end'])[0]
        status = ToolStatusCode.VALID_TOOL
        if self.validate:
            try:
                action = self.validate(action)
            except Exception:
                # Keep the raw action text so callers can report or retry.
                status = ToolStatusCode.PARSING_ERROR
        return dict(
            tool_type=self.tool_type,
            thought=thought,
            action=action,
            status=status)

    def format_response(self, parsed: dict) -> str:
        """Reassemble a dict produced by :meth:`parse_response` into text.

        Args:
            parsed (dict): Must contain ``thought``, ``action`` and, when
                ``action`` is not ``None``, a ``tool_type`` equal to this
                parser's ``tool_type``.

        Returns:
            str: ``thought`` alone when there is no action; otherwise
            ``thought + begin + action + end``, where a dict action is
            serialized as JSON (non-ASCII kept as-is) and any other action
            with ``str()``.
        """
        if parsed['action'] is None:
            return parsed['thought']
        assert parsed['tool_type'] == self.tool_type, (
            'tool_type mismatch: expected {!r}, got {!r}'.format(
                self.tool_type, parsed['tool_type']))
        if isinstance(parsed['action'], dict):
            action = json.dumps(parsed['action'], ensure_ascii=False)
        else:
            action = str(parsed['action'])
        return parsed['thought'] + self.format_field[
            'begin'] + action + self.format_field['end']


class InterpreterParser(ToolParser):
    """:class:`ToolParser` preset for code-interpreter tool calls.

    Defaults to ``tool_type='interpreter'`` and the
    ``<|action_start|><|interpreter|>`` / ``<|action_end|>`` markers (each
    followed by a newline). No ``validate`` callable is set by default, so
    the extracted action text is returned unchanged.
    """

    def __init__(self,
                 tool_type: str = 'interpreter',
                 template: str = '',
                 begin: str = '<|action_start|><|interpreter|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Optional[Callable[[str], Any]] = None,
                 **kwargs):
        # Forward by keyword so a change in ToolParser's positional order
        # cannot silently mis-assign these arguments.
        super().__init__(
            tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs)


class PluginParser(ToolParser):
    """:class:`ToolParser` preset for plugin tool calls.

    Defaults to ``tool_type='plugin'`` and the
    ``<|action_start|><|plugin|>`` / ``<|action_end|>`` markers (each
    followed by a newline). Each action is checked with
    :func:`default_plugin_validate`, so a well-formed call is returned as a
    decoded JSON object, and text that fails validation is reported with
    ``ToolStatusCode.PARSING_ERROR``.
    """

    def __init__(self,
                 tool_type: str = 'plugin',
                 template: str = '',
                 begin: str = '<|action_start|><|plugin|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Callable[[str], Any] = default_plugin_validate,
                 **kwargs):
        super().__init__(
            tool_type=tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs,
        )


class MixedToolParser(StrParser):
    """Combine several tool parsers behind a single parser interface.

    Args:
        tool_type (Optional[str]): Name attached to the instruction message
            built from this parser's own ``template``.
        template (str): Instruction template forwarded to ``StrParser``.
        parsers (List[ToolParser]): Sub-parsers, each passed through
            ``create_object`` first. They are keyed by ``tool_type``; a later
            sub-parser with the same ``tool_type`` replaces an earlier one.
        **format_field: Additional format fields forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: Optional[str] = None,
                 template='',
                 parsers: List[ToolParser] = None,
                 **format_field):
        self.tool_type = tool_type
        instantiated = (create_object(spec) for spec in (parsers or []))
        self.parsers = {sub.tool_type: sub for sub in instantiated}
        super().__init__(template, **format_field)

    def format_instruction(self) -> List[dict]:
        """Collect system messages describing the available tools.

        Returns:
            List[dict]: This parser's own instruction (named after
            ``tool_type`` when it is set) followed by one message per
            sub-parser, skipping any instruction that is blank.
        """
        messages = []
        own_text = super().format_instruction()
        if own_text.strip():
            own_message = dict(role='system', content=own_text)
            if self.tool_type:
                own_message['name'] = self.tool_type
            messages.append(own_message)
        for tool_name, sub_parser in self.parsers.items():
            text = sub_parser.format_instruction()
            if not text.strip():
                continue
            messages.append(dict(role='system', content=text, name=tool_name))
        return messages

    def parse_response(self, data: str) -> dict:
        """Parse ``data`` with the first sub-parser that recognizes a call.

        Sub-parsers are tried in registration order. The first result whose
        ``tool_type`` matches its parser's key is returned. If none matches,
        the last sub-parser's result is returned. With no sub-parsers, a
        ``NO_TOOL`` result wrapping ``data`` is returned.
        """
        outcome = None
        for tool_name, sub_parser in self.parsers.items():
            outcome = sub_parser.parse_response(data)
            if outcome['tool_type'] == tool_name:
                return outcome
        if outcome is None:
            outcome = dict(
                tool_type=None,
                thought=data,
                action=None,
                status=ToolStatusCode.NO_TOOL)
        return outcome

    def format_response(self, parsed: dict) -> str:
        """Render ``parsed`` with the sub-parser registered for its tool."""
        if parsed['action'] is None:
            return parsed['thought']
        tool_name = parsed['tool_type']
        assert tool_name in self.parsers
        return self.parsers[tool_name].format_response(parsed)