import hashlib
from datetime import datetime
from typing import List, Optional

import pandas as pd
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from streamlit.runtime.uploaded_file_manager import UploadedFile

from backend.chat_bot.tools import parse_files, extract_embedding
from backend.construct.build_retriever_tool import create_retriever_tool
from logger import logger


class ChatBotKnowledgeTable:
    """Per-user private knowledge bases backed by MyScale/ClickHouse.

    Two tables are created on construction:

    * ``kb_table`` -- one row per parsed paragraph of an uploaded file,
      together with its 768-dim embedding (cosine MSTG vector index).
    * ``tool_table`` -- one row per user-defined knowledge base ("tool"),
      grouping a set of file names under a tool name and description.
    """

    def __init__(self, host, port, username, password,
                 embedding: Embeddings, parser_api_key: str, db="chat",
                 kb_table="private_kb", tool_table="private_tool") -> None:
        super().__init__()
        # Paragraph-level storage: text plus its fixed-length (768) embedding.
        personal_files_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """

        # `tool_name` represents a private knowledge-base name.
        private_knowledge_base_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.personal_files_table = kb_table
        self.private_knowledge_base_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        self.client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(personal_files_schema_)
        self.client.command(private_knowledge_base_schema_)
        self.parser_api_key = parser_api_key
        self.vector_store = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    @staticmethod
    def _quote(value: str) -> str:
        """Escape backslashes and single quotes so *value* can be embedded
        in a single-quoted ClickHouse string literal.

        Basic SQL-injection hardening for user-supplied values such as
        ``user_id``, ``tool_name`` and file names that this class
        interpolates into query text.
        """
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    def list_files(self, user_id: str):
        """Return one row per uploaded file of *user_id* with its paragraph
        count (``num_paragraph``) and the longest paragraph length in
        characters (``max_chars``)."""
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vector_store.config.database}.{self.personal_files_table}
        WHERE user_id = '{self._quote(user_id)}' GROUP BY file_name
        """
        return list(self.vector_store.client.query(query).named_results())

    def add_by_file(self, user_id, files: List[UploadedFile]):
        """Parse the uploaded *files*, embed every extracted paragraph and
        insert the resulting rows into the personal-files table."""
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vector_store.embeddings, data)
        self.vector_store.client.insert_df(
            table=self.personal_files_table,
            df=pd.DataFrame(data),
            database=self.vector_store.config.database,
        )

    def clear(self, user_id: str):
        """Delete every file row and every knowledge base owned by *user_id*."""
        database = self.vector_store.config.database
        uid = self._quote(user_id)
        self.vector_store.client.command(
            f"DELETE FROM {database}.{self.personal_files_table} "
            f"WHERE user_id='{uid}'"
        )
        self.vector_store.client.command(
            f"DELETE FROM {database}.{self.private_knowledge_base_table} "
            f"WHERE user_id='{uid}'"
        )

    def create_private_knowledge_base(
            self, user_id: str, tool_name: str, tool_description: str,
            files: Optional[List[str]] = None,
    ):
        """Register knowledge base *tool_name* over *files* for *user_id*.

        The ``tool_id`` is a deterministic digest of ``user_id + tool_name``,
        so the ReplacingMergeTree engine de-duplicates repeated creations of
        the same (user, tool) pair.
        """
        self.vector_store.client.insert_df(
            self.private_knowledge_base_table,
            pd.DataFrame(
                [
                    {
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        # `tool_name` names the user's private knowledge base.
                        "tool_name": tool_name,
                        # Guard against None: the column is Array(String).
                        "file_names": files or [],
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vector_store.config.database,
        )

    def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
        """List the knowledge bases of *user_id* (name, description, file
        count), optionally restricted to a single *private_knowledge_base*."""
        extended_where = (
            f"AND tool_name = '{self._quote(private_knowledge_base)}'"
            if private_knowledge_base else ""
        )
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
        WHERE user_id = '{self._quote(user_id)}' {extended_where}
        """
        return list(self.vector_store.client.query(query).named_results())

    def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
        """Delete the named knowledge bases belonging to *user_id*."""
        if not private_knowledge_bases:
            # Nothing to delete; an empty IN-list would be invalid SQL.
            return
        unique_names = ",".join(
            f"'{self._quote(name)}'" for name in set(private_knowledge_bases)
        )
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id  = '{self._quote(user_id)}' AND tool_name IN [{unique_names}]"""
        self.vector_store.client.command(query)

    def as_retrieval_tools(self, user_id, tool_name=None):
        """Build one retriever tool per knowledge base of *user_id*
        (optionally restricted to *tool_name*).

        Returns a dict mapping each tool name to a retriever scoped to that
        user's files. Knowledge bases with no files are skipped, since an
        empty ``IN ()`` filter would fail at query time.
        """
        private_knowledge_bases = self.list_private_knowledge_bases(
            user_id=user_id, private_knowledge_base=tool_name
        )
        database = self.vector_store.config.database
        retrievers = {}
        for private_kb in private_knowledge_bases:
            kb_name = private_kb["tool_name"]
            # Flatten the tool's stored file list; use the configured
            # database/table rather than a hard-coded `chat.private_tool`.
            file_names_sql = f"""
            SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM {database}.{self.private_knowledge_base_table}
                WHERE user_id = '{self._quote(user_id)}'
                  AND tool_name = '{self._quote(kb_name)}'
            )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = [row[0] for row in res.result_rows]
            if not file_names:
                logger.warning(
                    f"knowledge base {kb_name} of user {user_id} has no files; skipped"
                )
                continue
            in_list = ", ".join(f"'{self._quote(item)}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {in_list}")
            retrievers[kb_name] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={
                        "where_str": (
                            f"user_id='{self._quote(user_id)}' "
                            f"AND file_name IN ({in_list})"
                        )
                    },
                ),
                tool_name=kb_name,
                description=private_kb["tool_description"],
            )
        return retrievers