import copy
from typing import List, Union

from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM
from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode


class StreamingAgentMixin:
    """Make agent calling output a streaming response.

    ``__call__`` becomes a generator. It yields a shallow copy of every chunk
    produced by ``forward``. It then stores the last chunk in memory, runs the
    ``after_agent`` hooks and yields the finished response once more.
    """

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        """Run the agent on ``message`` and stream its response.

        Args:
            *message: Input message(s); stored in memory before ``forward``.
            session_id: Memory session to read from and write to.
            **kwargs: Forwarded to ``self.forward``.

        Yields:
            ``AgentMessage`` chunks, followed by the final (hook-processed)
            response message.
        """
        # Work on a private copy so neither the hooks nor the memory entry
        # alias objects owned by the caller.
        message = copy.deepcopy(message)
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Placeholder that is stored/yielded if ``forward`` produces no chunks.
        response_message = AgentMessage(sender=self.name, content="")
        for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                # ``forward`` may yield raw ``(model_state, response)`` pairs.
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        # Detach the final message from the object just stored in memory, so
        # hooks or callers that mutate yielded messages (e.g. their
        # ``stream_state``) do not rewrite the agent's history.
        response_message = response_message.model_copy(deep=True)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message


class AsyncStreamingAgentMixin:
    """Make asynchronous agent calling output a streaming response.

    ``__call__`` becomes an async generator. It yields a shallow copy of every
    chunk produced by ``forward``. It then stores the last chunk in memory,
    runs the ``after_agent`` hooks and yields the finished response once more.
    """

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        """Run the agent on ``message`` and stream its response asynchronously.

        Args:
            *message: Input message(s); stored in memory before ``forward``.
            session_id: Memory session to read from and write to.
            **kwargs: Forwarded to ``self.forward``.

        Yields:
            ``AgentMessage`` chunks, followed by the final (hook-processed)
            response message.
        """
        # Work on a private copy so neither the hooks nor the memory entry
        # alias objects owned by the caller.
        message = copy.deepcopy(message)
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Placeholder that is stored/yielded if ``forward`` produces no chunks.
        response_message = AgentMessage(sender=self.name, content="")
        async for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                # ``forward`` may yield raw ``(model_state, response)`` pairs.
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        # Detach the final message from the object just stored in memory, so
        # hooks or callers that mutate yielded messages (e.g. their
        # ``stream_state``) do not rewrite the agent's history.
        response_message = response_message.model_copy(deep=True)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message


class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent class"""

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Stream the LLM reply for the current contents of memory.

        The prompt is built from ``self.memory.get(session_id)``. The
        positional ``message`` arguments are not read here; the mixin's
        ``__call__`` has already written them to memory.

        Yields:
            With an ``output_format``: an ``AgentMessage`` per chunk whose
            ``formatted`` field is the parsed partial response. Without one:
            a raw ``(model_state, response)`` tuple for the caller to wrap.
        """
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id), self.name, self.output_format, self.template
        )
        parser = self.output_format
        chunks = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        for state, text, _ in chunks:
            if not parser:
                yield (state, text)
                continue
            yield AgentMessage(
                sender=self.name,
                content=text,
                formatted=parser.parse_response(text),
                stream_state=state,
            )


class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent class"""

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Asynchronously stream the LLM reply for the current memory.

        The prompt is built from ``self.memory.get(session_id)``. The
        positional ``message`` arguments are not read here; the mixin's
        ``__call__`` has already written them to memory.

        Yields:
            With an ``output_format``: an ``AgentMessage`` per chunk whose
            ``formatted`` field is the parsed partial response. Without one:
            a raw ``(model_state, response)`` tuple for the caller to wrap.
        """
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id), self.name, self.output_format, self.template
        )
        parser = self.output_format
        chunks = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        async for state, text, _ in chunks:
            if not parser:
                yield (state, text)
                continue
            yield AgentMessage(
                sender=self.name,
                content=text,
                formatted=parser.parse_response(text),
                stream_state=state,
            )


class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming implementation of `lagent.agents.AgentForInternLM`"""

    _INTERNAL_AGENT_CLS = StreamingAgent

    @staticmethod
    def _tool_stream_state(message: AgentMessage):
        """Return the agent status code for a chunk that carries a tool call.

        While the model is still generating, the state is ``PLUGIN_START``
        for ``"plugin"`` calls and ``CODING`` for any other tool type. Once
        the model reports ``ModelStatusCode.END``, the state is that start
        code + 1.

        The state is derived from the tool type rather than from the previous
        chunk's state. This keeps it correct when the tool call first shows up
        in the final chunk (e.g. a response delivered as a single chunk).
        """
        start_state = (
            AgentStatusCode.PLUGIN_START
            if message.formatted["tool_type"] == "plugin"
            else AgentStatusCode.CODING
        )
        if message.stream_state == ModelStatusCode.END:
            return start_state + 1
        return start_state

    def forward(self, message: Union[str, AgentMessage], session_id=0, **kwargs):
        """Run up to ``self.max_turn`` model/tool turns, streaming each step.

        Args:
            message: Initial input; a ``str`` is wrapped as a user message.
            session_id: Session passed to the internal agent and executors.
            **kwargs: Forwarded to the internal agent.

        Yields:
            Each model chunk with an agent-level ``stream_state``:
            ``STREAM_ING`` for plain text, or a tool state (see
            ``_tool_stream_state``). Also yields each executor's return (state
            = previous + 1) and, when ``finish_condition`` holds, the final
            message with ``AgentStatusCode.END``.

        Raises:
            RuntimeError: If the model requests a tool type with no matching
                ``<tool_type>_executor`` attribute.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    message.stream_state = self._tool_stream_state(message)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = executor(message, session_id=session_id)
                # The tool return follows the tool's end state.
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # Not finished and no tool requested: feed the message back
                # to the internal agent on the next turn.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message


class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming implementation of `lagent.agents.AsyncAgentForInternLM`"""

    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    @staticmethod
    def _tool_stream_state(message: AgentMessage):
        """Return the agent status code for a chunk that carries a tool call.

        While the model is still generating, the state is ``PLUGIN_START``
        for ``"plugin"`` calls and ``CODING`` for any other tool type. Once
        the model reports ``ModelStatusCode.END``, the state is that start
        code + 1.

        The state is derived from the tool type rather than from the previous
        chunk's state. This keeps it correct when the tool call first shows up
        in the final chunk (e.g. a response delivered as a single chunk).
        """
        start_state = (
            AgentStatusCode.PLUGIN_START
            if message.formatted["tool_type"] == "plugin"
            else AgentStatusCode.CODING
        )
        if message.stream_state == ModelStatusCode.END:
            return start_state + 1
        return start_state

    async def forward(self, message: Union[str, AgentMessage], session_id=0, **kwargs):
        """Run up to ``self.max_turn`` model/tool turns, streaming each step.

        Args:
            message: Initial input; a ``str`` is wrapped as a user message.
            session_id: Session passed to the internal agent and executors.
            **kwargs: Forwarded to the internal agent.

        Yields:
            Each model chunk with an agent-level ``stream_state``:
            ``STREAM_ING`` for plain text, or a tool state (see
            ``_tool_stream_state``). Also yields each executor's return (state
            = previous + 1) and, when ``finish_condition`` holds, the final
            message with ``AgentStatusCode.END``.

        Raises:
            RuntimeError: If the model requests a tool type with no matching
                ``<tool_type>_executor`` attribute.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    message.stream_state = self._tool_stream_state(message)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = await executor(message, session_id=session_id)
                # The tool return follows the tool's end state.
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # Not finished and no tool requested: feed the message back
                # to the internal agent on the next turn.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message