import html
import json
import os
import queue
import shutil
import sys
import threading
import uuid
from pathlib import Path
from time import sleep

# Point Gradio's temp directory at the current working directory; this runs
# before the `import gradio` further down (presumably so Gradio picks it up).
os.environ['GRADIO_TEMP_DIR'] = os.getcwd()
# Local videos offered by the UI live under <cwd>/video_root; create it if missing.
video_root_path = os.path.join(os.environ['GRADIO_TEMP_DIR'], 'video_root')
os.makedirs(video_root_path, exist_ok=True)

from omagent_core.clients.devices.app.callback import AppCallback
from omagent_core.clients.devices.app.input import AppInput
from omagent_core.clients.devices.app.schemas import ContentStatus, MessageType
from omagent_core.engine.automator.task_handler import TaskHandler
from omagent_core.engine.http.models.workflow_status import terminal_status
from omagent_core.engine.workflow.conductor_workflow import ConductorWorkflow
from omagent_core.services.connectors.redis import RedisConnector
from omagent_core.utils.build import build_from_file
from omagent_core.utils.container import container
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry

# Presumably imports project modules so their classes register themselves with
# `registry` (WebpageClient later resolves workers via registry.get_worker) — confirm.
registry.import_module()

# Register shared services with the container: RedisConnector under the name
# "redis_stream_client" (looked up by WebpageClient for all stream I/O), and
# AppCallback / AppInput as the container's callback and input.
container.register_connector(name="redis_stream_client", connector=RedisConnector)
# STM registration is currently disabled:
# container.register_stm(stm='RedisSTM')
container.register_callback(callback=AppCallback)
container.register_input(input=AppInput)

import gradio as gr


class WebpageClient:
    """Gradio front end for an OmAgent video question-answering app.

    Two entry points are provided:

    * :meth:`start_interactor` launches a UI where the user selects a local
      video (which runs the ``processor`` workflow on it) and then chats about
      it. Chat turns are written to the ``<id>_input`` Redis stream and replies
      are read back from ``<id>_running`` / ``<id>_output``, where ``<id>`` is
      the interactor workflow instance id.
    * :meth:`start_processor` starts the processor task handler and launches an
      upload UI on port 7861 whose files are pushed to the ``image_process``
      Redis stream.
    """

    # File extensions (lower-case, without the dot) offered in the video picker.
    _VIDEO_EXTENSIONS = ('mp4', 'avi', 'mov', 'wmv', 'flv', 'mkv', 'webm', 'm4v')

    def __init__(
            self,
            interactor: ConductorWorkflow = None,
            processor: ConductorWorkflow = None,
            config_path: str = "./config",
            workers: list = None,
    ) -> None:
        """Create the client and instantiate all workers.

        Args:
            interactor: Workflow started with the user's question in the chat UI.
            processor: Workflow run on the selected video (chat UI) or on
                uploaded files (:meth:`start_processor`).
            config_path: Path passed to ``build_from_file`` to load worker configs.
            workers: Pre-built worker instances, registered by class name next
                to the workers built from ``config_path``. Defaults to none.
        """
        # Fresh list per instance: a `[]` default would be shared by every
        # WebpageClient created without `workers`.
        workers = [] if workers is None else workers
        self._interactor = interactor
        self._processor = processor
        self._config_path = config_path
        self._workers = workers
        # Id of the workflow currently driving the UI; None means "none active",
        # which makes the next submitted message start a new workflow.
        self._workflow_instance_id = None
        self._worker_config = build_from_file(self._config_path)
        self._task_to_domain = {}
        # Buffer for streamed INCOMPLETE answer chunks until the final chunk arrives.
        self._incomplete_message = ""
        self._custom_css = """
            #OmAgent {
                height: 100vh !important;
                max-height: calc(100vh - 190px) !important;
                overflow-y: auto;
            }

            .running-message {
                margin: 0;
                padding: 2px 4px;
                white-space: pre-wrap;
                word-wrap: break-word;
                font-family: inherit;
            }

            /* Remove the background and border of the message box */
            .message-wrap {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                margin: 0 !important;
            }

            /* Remove the bubble style of the running message */
            .message:has(.running-message) {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                box-shadow: none !important;
            }
        """
        self.workflow_instance_id = str(uuid.uuid4())
        self.processor_instance_id = str(uuid.uuid4())
        # NOTE(review): the config is parsed a second time instead of reusing
        # self._worker_config; presumably so the workers built here and the
        # TaskHandler don't share config objects — confirm before deduplicating.
        worker_config = build_from_file(self._config_path)
        self.initialization(workers, worker_config)

    def initialization(self, workers, worker_config):
        """Populate ``self.workers`` and bind every worker to this client.

        Pre-built ``workers`` are keyed by class name. Each ``worker_config``
        entry is resolved with ``registry.get_worker(config['name'])``,
        constructed with the entry as keyword arguments and keyed by that name
        (replacing a pre-built worker whose class name matches). Every worker's
        ``workflow_instance_id`` is set to ``self.workflow_instance_id``.
        """
        self.workers = {}
        for worker in workers:
            worker.workflow_instance_id = self.workflow_instance_id
            self.workers[type(worker).__name__] = worker

        for config in worker_config:
            worker_cls = registry.get_worker(config['name'])
            worker = worker_cls(**config)
            worker.workflow_instance_id = self.workflow_instance_id
            self.workers[config['name']] = worker

    @classmethod
    def _load_local_video(cls) -> dict:
        """Return ``{file name: full path}`` for every video under ``video_root_path``.

        The tree is walked recursively. A file counts as a video when the text
        after its last '.' (case-insensitive) is in ``_VIDEO_EXTENSIONS``.
        Files with the same name in different sub-directories collide; the one
        visited last wins.
        """
        videos = {}
        for root, _, files in os.walk(video_root_path):
            for file_name in files:
                if file_name.split('.')[-1].lower() in cls._VIDEO_EXTENSIONS:
                    videos[file_name] = os.path.join(root, file_name)
        return videos

    @staticmethod
    def _state_dict(state):
        """Return the dict behind a chat ``state`` argument.

        Handlers may receive either the ``gr.State`` component itself (its
        ``.value`` is returned) or the plain session value, which is returned
        unchanged (possibly ``None``).
        """
        if isinstance(state, gr.State):
            return state.value
        return state

    def gradio_app(self):
        """Build and launch the video-chat Gradio UI; blocks while it is served.

        Left column: a dropdown of local videos and a player. Selecting a video
        runs the processor workflow on it. Right column: a chat box wired to
        :meth:`add_message` and then :meth:`bot`.
        """
        with gr.Blocks() as demo:
            state = gr.State(value={
                'video_dict': self._load_local_video(),
                'current_video': None,
            })
            with gr.Row():
                with gr.Column():
                    with gr.Column():
                        def display_video_map(video_title):
                            """Run the processor on the chosen video and record the result.

                            Blocks, polling once per second, until the processor
                            workflow reaches a terminal status. Then stores the
                            video path and processor output in ``state.value``
                            and returns ``(video_path, state)``.
                            """
                            # NOTE(review): this reads and mutates the component's
                            # default value (not a per-session value) and returns the
                            # component itself as the new state. That appears to be why
                            # add_message/bot accept both a gr.State and a dict — confirm
                            # against the Gradio version in use.
                            video_path = state.value.get('video_dict', {}).get(video_title)
                            try:
                                processor_result = self._processor.start_workflow_with_input(
                                    workflow_input={'video_path': video_path},
                                    workers=self.workers,
                                )
                            except Exception as e:
                                logging.error(f"Error starting workflow: {e}")
                                raise

                            # NOTE(review): the id polled here is a uuid generated in
                            # __init__ and never passed to start_workflow_with_input
                            # above; confirm get_workflow resolves it as intended.
                            processor_workflow_instance_id = self.processor_instance_id
                            while self._processor.get_workflow(
                                    workflow_id=processor_workflow_instance_id
                            ).status not in terminal_status:
                                sleep(1)

                            # Re-scan so newly added files are known to the state.
                            state.value['video_dict'] = self._load_local_video()
                            state.value.update(
                                current_video=video_path,
                                processor_result=processor_result,
                                processor_workflow_instance_id=processor_workflow_instance_id,
                            )
                            return video_path, state

                        select_video = gr.Dropdown(
                            list(state.value['video_dict']),
                            value=None,
                        )
                        display_video = gr.Video(
                            state.value['current_video'],
                        )
                        select_video.change(
                            fn=display_video_map,
                            inputs=[select_video],
                            outputs=[display_video, state],
                        )

                with gr.Column():
                    chatbot = gr.Chatbot(
                        type="messages",
                    )

                    chat_input = gr.Textbox(
                        interactive=True,
                        placeholder="Enter message...",
                        show_label=False,
                    )

                    chat_msg = chat_input.submit(
                        self.add_message,
                        [chatbot, chat_input, state],
                        [chatbot, chat_input],
                    )
                    bot_msg = chat_msg.then(
                        self.bot, (chatbot, state), chatbot, api_name="bot_response"
                    )
                    # Re-enable the textbox (disabled by add_message) once bot finishes.
                    bot_msg.then(
                        lambda: gr.Textbox(interactive=True), None, [chat_input]
                    )

            demo.launch(
                max_file_size='1gb'
            )

    def start_interactor(self):
        """Launch the video-chat UI (blocking).

        On Ctrl+C, terminate the active interactor workflow (if any) and
        re-raise the ``KeyboardInterrupt``.
        """
        try:
            self.gradio_app()
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._interactor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_interactor(self):
        """Print a notice and exit the process with status 0."""
        # No interactor task handler is started by this class, so there is
        # nothing to stop before exiting.
        print("stop_interactor")
        sys.exit(0)

    def start_processor(self):
        """Start processor workers and launch the upload UI on port 7861 (blocking).

        The task handler is started with the worker config loaded in
        ``__init__``. Submissions go through :meth:`add_processor_message` and
        :meth:`processor_bot`. On Ctrl+C the active processor workflow (if any)
        is terminated and the interrupt re-raised.
        """
        self._task_handler_processor = TaskHandler(
            worker_config=self._worker_config, workers=self._workers, task_to_domain=self._task_to_domain
        )
        self._task_handler_processor.start_processes()

        try:
            with gr.Blocks(title="OmAgent", css=self._custom_css) as chat_interface:
                chatbot = gr.Chatbot(
                    elem_id="OmAgent",
                    bubble_full_width=False,
                    type="messages",
                    height="100%",
                )

                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_count="multiple",
                    placeholder="Enter message or upload file...",
                    show_label=False,
                )

                chat_msg = chat_input.submit(
                    self.add_processor_message,
                    [chatbot, chat_input],
                    [chatbot, chat_input],
                )
                bot_msg = chat_msg.then(
                    self.processor_bot, chatbot, chatbot, api_name="bot_response"
                )
                bot_msg.then(
                    lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
                )
            chat_interface.launch(server_port=7861)
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._processor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_processor(self):
        """Stop the processes started by :meth:`start_processor`."""
        self._task_handler_processor.stop_processes()

    def _run_interactor_workflow(self, workflow_input):
        """Start the interactor workflow; run on a background thread by add_message.

        Failures are logged and re-raised, so the thread's excepthook also
        reports them.
        """
        try:
            self._interactor.start_workflow_with_input(
                workflow_input=workflow_input, workers=self.workers
            )
        except Exception as e:
            logging.error(f"Error starting workflow: {e}")
            raise

    def add_message(self, history, message, state):
        """Handle a chat message submitted in the video-chat UI.

        If no video has been selected, append the message plus a prompt to
        (re)select one. Otherwise start the interactor workflow on a daemon
        thread when none is active, append the message to ``history`` and push
        it to the ``<workflow id>_input`` Redis stream.

        Args:
            history: Chatbot history in ``type="messages"`` form; mutated in place.
            message: The submitted text.
            state: Session state, either a ``gr.State`` or its plain dict value.

        Returns:
            ``(history, gr.Textbox)``. The textbox is cleared and disabled until
            the follow-up :meth:`bot` step re-enables it.
        """
        state_value = self._state_dict(state)
        if not isinstance(state, gr.State) and not state:
            reply = 'Please reselect the video'
        elif not state_value or state_value.get('current_video') is None:
            reply = 'Please select a video'
        else:
            reply = None
        if reply is not None:
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": reply})
            return history, gr.Textbox(value=None, interactive=False)

        if self._workflow_instance_id is None:
            # `or {}` also covers a stored processor_result of None.
            processor_result = state_value.get('processor_result') or {}
            workflow_input = {
                'question': message,
                "video_md5": processor_result.get("video_md5"),
                "video_path": processor_result.get("video_path"),
                "instance_id": processor_result.get("instance_id"),
                "processor_workflow_instance_id": state_value.get("processor_workflow_instance_id"),
            }
            threading.Thread(
                target=self._run_interactor_workflow,
                args=(workflow_input,),
                daemon=True,
            ).start()
            # Stream names use the id the workers were bound to in initialization().
            self._workflow_instance_id = self.workflow_instance_id

        history.append({"role": "user", "content": message})
        payload = {
            "agent_id": self._workflow_instance_id,
            "messages": [{"role": "user", "content": [{"data": message, "type": "text"}]}],
            "kwargs": {},
        }
        container.get_connector("redis_stream_client")._client.xadd(
            f"{self._workflow_instance_id}_input",
            {"payload": json.dumps(payload, ensure_ascii=False)},
        )
        return history, gr.Textbox(value=None, interactive=False)

    def add_processor_message(self, history, message):
        """Forward files uploaded in the processor UI to the ``image_process`` stream.

        Starts the processor workflow first if none is active. Each file is
        shown in ``history`` and sent as an ``image_url`` item whose
        ``resource_id`` is its index. Only ``message["files"]`` is used; any
        typed text is ignored.

        Returns:
            ``(history, gr.MultimodalTextbox)``, with the input cleared and disabled.
        """
        if self._workflow_instance_id is None:
            self._workflow_instance_id = self._processor.start_workflow_with_input(
                workflow_input={}, task_to_domain=self._task_to_domain
            )
        image_items = []
        for idx, path in enumerate(message["files"]):
            history.append({"role": "user", "content": {"path": path}})
            image_items.append(
                {"type": "image_url", "resource_id": str(idx), "data": str(path)}
            )
        container.get_connector("redis_stream_client")._client.xadd(
            "image_process",
            {"payload": json.dumps({"content": image_items}, ensure_ascii=False)},
        )
        return history, gr.MultimodalTextbox(value=None, interactive=False)

    def _merge_output_message(self, history, message_item, incomplete):
        """Apply one ``<id>_output`` message item to ``history`` in place.

        * Image items are appended as ``{"path": ...}`` content.
        * Text is accumulated in ``self._incomplete_message`` while chunks are
          INCOMPLETE, or when a buffer is pending. The accumulated text
          overwrites the last assistant entry, or is appended if the last
          entry is not an assistant one. The first non-incomplete chunk
          completes the text and clears the buffer.
        * Any other text is appended as a new assistant entry.
        """
        if message_item["type"] == MessageType.IMAGE_URL.value:
            history.append({"role": "assistant", "content": {"path": message_item["content"]}})
            return
        if not incomplete and self._incomplete_message == "":
            history.append({"role": "assistant", "content": message_item["content"]})
            return
        self._incomplete_message += message_item["content"]
        if history and history[-1]["role"] == "assistant":
            history[-1]["content"] = self._incomplete_message
        else:
            history.append({"role": "assistant", "content": self._incomplete_message})
        if not incomplete:
            self._incomplete_message = ""

    def bot(self, history, state):
        """Stream assistant output for the current chat turn into ``history``.

        This is a generator used as a Gradio event handler. If no video is
        selected, it yields ``history`` unchanged and stops. Otherwise it
        ensures the consumer group ``omappagent`` exists on ``<id>_running`` and
        ``<id>_output``. It then polls both streams (sleeping 10 ms between
        rounds) and yields ``history`` after each message:

        * ``_running`` payloads become an HTML-escaped
          ``<pre class="running-message">progress: message</pre>`` entry.
        * ``_output`` payloads are merged via :meth:`_merge_output_message`.

        Each handled message is acknowledged after it is yielded. The loop ends
        after a batch containing ``interaction_type == 1`` or an END_ANSWER
        ``content_status``. END_ANSWER also clears the active workflow id, so
        the next question starts a new interactor workflow.
        """
        state_value = self._state_dict(state)
        if not state_value or state_value.get('current_video') is None:
            yield history
            return

        client = container.get_connector("redis_stream_client")._client
        stream_name = f"{self._workflow_instance_id}_output"
        consumer_name = f"{self._workflow_instance_id}_agent"  # consumer name
        group_name = "omappagent"  # consumer group shared by both streams
        running_stream_name = f"{self._workflow_instance_id}_running"
        self._check_redis_stream_exist(stream_name, group_name)
        self._check_redis_stream_exist(running_stream_name, group_name)
        while True:
            # Progress stream: each update becomes its own <pre> line.
            running_messages = self._get_redis_stream_message(
                group_name, consumer_name, running_stream_name
            )
            for _, message_list in running_messages:
                for message_id, raw_message in message_list:
                    payload_data = self._get_message_payload(raw_message)
                    if payload_data is None:
                        continue
                    # Rendered as HTML, so escape; str() guards non-string values.
                    progress = html.escape(str(payload_data.get("progress", "")))
                    text = html.escape(str(payload_data.get("message", "")))
                    history.append({
                        "role": "assistant",
                        "content": f'<pre class="running-message">{progress}: {text}</pre>',
                    })
                    yield history
                    client.xack(running_stream_name, group_name, message_id)

            # Output stream: the actual answer (text chunks or images).
            output_messages = self._get_redis_stream_message(
                group_name, consumer_name, stream_name
            )
            finish_flag = False
            for _, message_list in output_messages:
                for message_id, raw_message in message_list:
                    payload_data = self._get_message_payload(raw_message)
                    if payload_data is None:
                        continue
                    content_status = payload_data.get("content_status")
                    self._merge_output_message(
                        history,
                        payload_data["message"],
                        incomplete=content_status == ContentStatus.INCOMPLETE.value,
                    )
                    yield history
                    client.xack(stream_name, group_name, message_id)

                    if payload_data.get("interaction_type") == 1:
                        finish_flag = True
                    if content_status == ContentStatus.END_ANSWER.value:
                        self._workflow_instance_id = None
                        finish_flag = True

            if finish_flag:
                break
            sleep(0.01)

    def processor_bot(self, history: list):
        """Show "processing..." and poll the processor workflow until it ends.

        Polls every 10 ms. When the status is terminal, appends "completed",
        yields, and clears the active workflow id.
        """
        history.append({"role": "assistant", "content": "processing..."})
        yield history
        while self._processor.get_workflow(
                workflow_id=self._workflow_instance_id
        ).status not in terminal_status:
            sleep(0.01)
        history.append({"role": "assistant", "content": "completed"})
        yield history
        self._workflow_instance_id = None

    def _get_redis_stream_message(
            self, group_name: str, consumer_name: str, stream_name: str
    ):
        """Read at most one new message for ``consumer_name`` from ``stream_name``.

        Returns ``[(stream, [(message_id, {field: value}), ...]), ...]``.
        ``bytes`` field names and values are decoded as UTF-8 and ``str`` ones
        are kept as-is. Returns ``[]`` when the client returns nothing.
        Messages are not acknowledged here.
        """
        raw = container.get_connector("redis_stream_client")._client.xreadgroup(
            group_name, consumer_name, {stream_name: ">"}, count=1
        )
        return [
            (
                stream,
                [
                    (
                        message_id,
                        {self._decode(k): self._decode(v) for k, v in fields.items()},
                    )
                    for message_id, fields in message_list
                ],
            )
            for stream, message_list in (raw or [])
        ]

    @staticmethod
    def _decode(value):
        """Decode UTF-8 ``bytes`` to ``str``; return anything else unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _check_redis_stream_exist(self, stream_name: str, group_name: str):
        """Ensure consumer group ``group_name`` exists on ``stream_name``.

        Creates the stream if needed (``mkstream=True``), with the group
        starting at id "0". Any error is logged at debug level and ignored.
        The expected error is that the group already exists.
        """
        try:
            container.get_connector("redis_stream_client")._client.xgroup_create(
                stream_name, group_name, id="0", mkstream=True
            )
        except Exception as e:
            logging.debug(f"Consumer group may already exist: {e}")

    def _get_message_payload(self, message: dict):
        """Return the JSON-decoded ``payload`` field of a stream message.

        Returns ``None`` (after logging an error) when the field is missing or
        empty, or is not valid JSON.
        """
        logging.info(f"Received running message: {message}")
        payload = message.get("payload")
        if not payload:
            logging.error("Payload is empty")
            return None
        try:
            payload_data = json.loads(payload)
        except json.JSONDecodeError as e:
            logging.error(f"Payload is not a valid JSON: {e}")
            return None
        return payload_data