# query를 자동으로 읽고 쓰는 container를 정의

from __future__ import annotations

import re
from typing import Callable, TypeVar

import streamlit as st

__all__ = ["QueryWrapper", "get_base_url"]

T = TypeVar("T")


import hashlib
import urllib.parse


def SHA1(msg: str) -> str:
    """Return the first eight hex characters of the SHA-1 digest of ``msg``.

    The text is encoded with the default codec (UTF-8) before hashing.
    """
    digest = hashlib.sha1(msg.encode())
    hex_digest = digest.hexdigest()
    return hex_digest[0:8]


def get_base_url():
    """Return ``<protocol>://<host>`` built from an active session's request.

    The URL is assembled with ``urllib.parse.urlunparse`` using only the
    scheme and network-location components; path, params, query and fragment
    are left empty.

    NOTE(review): this reaches into the private ``_session_mgr`` attribute of
    the Streamlit runtime and takes index 0 of the active-session list.
    Whether that entry is the caller's own session is not established here,
    and an empty list raises ``IndexError``.
    """
    runtime = st.runtime.get_instance()
    first_session = runtime._session_mgr.list_active_sessions()[0]
    scheme = first_session.client.request.protocol
    netloc = first_session.client.request.host
    components = [scheme, netloc, "", "", "", ""]
    return urllib.parse.urlunparse(components)


class QueryWrapper:
    """Public handle around a ``_QueryWrapper`` that also registers it.

    Every instance stores its underlying ``_QueryWrapper`` in the class-level
    ``queries`` dict under the ``query`` name, so that
    ``get_sharable_link`` can later serialise all of them at once. Creating a
    second instance with the same ``query`` replaces the earlier registry
    entry.
    """

    # Registry shared by all instances (kept for record-keeping):
    # query name -> the wrapper most recently created for that name.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Create the underlying wrapper and record it under ``query``."""
        wrapper = _QueryWrapper(query, label, use_hash)
        self.__wrapper = wrapper
        QueryWrapper.queries[query] = wrapper

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the underlying wrapper."""
        return self.__wrapper(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls):
        """Return the ``&``-joined string forms of all registered wrappers.

        Each registered value is converted with ``str``; the pieces are joined
        with ``&``, runs of ``&`` (left behind by empty pieces) are collapsed
        to a single ``&``, and leading/trailing ``&`` are stripped.
        """
        fragments = [str(wrapper) for wrapper in cls.queries.values()]
        joined = "&".join(fragments)
        collapsed = re.sub("&+", "&", joined)
        return collapsed.strip("&")


class _QueryWrapper:
    """Bind one URL query parameter to one Streamlit input widget.

    Calling the instance renders the widget, pre-selecting values taken from
    (in priority order) the URL query parameter, the ``default`` argument, or
    the value previously stored in ``st.session_state`` under the widget key.
    ``str(instance)`` renders the current selection back into
    ``name=token`` pairs suitable for a query string.

    Values never appear in the URL directly: each legal value is represented
    by a short token, ``SHA1(str(value))``, and ``hash_table`` maps tokens
    back to the original values.
    """

    # NOTE(review): not referenced anywhere in this class.
    ILLEGAL_CHARS = "&/=?"

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Store the query name and widget label; no widget is rendered yet.

        Args:
            query: Query-parameter name (lower-cased when read or written).
            label: Widget label; falls back to ``query`` when empty or None.
            use_hash: Stored on the instance only.
        """
        self.query = query
        self.label = label or query
        # NOTE(review): stored but never consulted; values are always hashed.
        self.use_hash = use_hash
        # token -> legal value; rebuilt on every call from ``legal_list``.
        self.hash_table: dict[str, object] = {}
        # Widget / session-state key; set on the first call.
        self.key: str | None = None

    @staticmethod
    def _token(value) -> str:
        """Return the URL token that stands for ``value``."""
        return SHA1(str(value))

    @staticmethod
    def _as_checked(selected) -> bool:
        """Reduce a selection (scalar or sequence) to a checkbox state."""
        if isinstance(selected, (list, tuple)):
            return bool(selected[0]) if selected else False
        return bool(selected)

    def _initial_selection(self, legal_list: list[T], default: T | list[T] | None):
        """Pick the pre-selected value(s): query > default > session state."""
        tokens = st.query_params.get_all(self.query.lower())
        if tokens:
            # Unknown tokens are always dropped, so an empty ``legal_list``
            # yields an empty selection instead of a KeyError.
            return [self.hash_table[t] for t in tokens if t in self.hash_table]
        if default:
            return default
        # Look up the same key the widget is registered under, so a custom
        # ``key`` and the default (label) key behave alike.
        if self.key in st.session_state:
            previous = st.session_state[self.key]
            if legal_list:
                if isinstance(previous, list):
                    return [v for v in previous if v in legal_list]
                if previous not in legal_list:
                    return []
            return previous
        return []

    def __call__(
        self,
        base_container: Callable,
        legal_list: list[T],
        default: T | list[T] | None = None,
        *,
        key: str | None = None,
        **kwargs,
    ) -> T | list[T] | None:
        """Render ``base_container`` with its selection restored.

        Args:
            base_container: Widget factory, e.g. ``st.selectbox``,
                ``st.radio``, ``st.multiselect`` or ``st.checkbox``. Any other
                callable is invoked as
                ``base_container(label, legal_list, key=..., **kwargs)``.
            legal_list: Values the widget may take; also defines which query
                tokens are accepted.
            default: Selection used when the query parameter is absent.
                Falsy defaults are treated as "no default".
            key: Widget key; falls back to the label.
            **kwargs: Passed through to ``base_container``.

        Returns:
            The value stored in ``st.session_state`` under the widget key, or
            the container's own return value when the key was not registered.
        """
        self.key = key or self.label
        self.hash_table = {self._token(v): v for v in legal_list}

        selected = self._initial_selection(legal_list, default)

        single_choice = base_container in (st.selectbox, st.radio)
        # Single-choice widgets want the value itself, not a one-element
        # list. Only real sequences are unwrapped, so a scalar such as an int
        # no longer fails on ``len()``.
        if single_choice and isinstance(selected, (list, tuple)) and len(selected) == 1:
            selected = selected[0]

        if base_container == st.checkbox:
            # st.checkbox takes a boolean ``value``; it has no options list
            # and no ``index`` parameter.
            # NOTE(review): a round trip through the URL presumably requires
            # ``legal_list`` to contain the booleans themselves - confirm.
            value = base_container(
                self.label, value=self._as_checked(selected), key=self.key, **kwargs
            )
        elif base_container == st.multiselect:
            value = base_container(
                self.label, legal_list, default=selected, key=self.key, **kwargs
            )
        elif single_choice:
            index = legal_list.index(selected) if selected in legal_list else None
            value = base_container(
                self.label, legal_list, index=index, key=self.key, **kwargs
            )
        else:
            value = base_container(self.label, legal_list, key=self.key, **kwargs)
        return st.session_state.get(self.key, value)

    def __str__(self):
        """Serialise the current selection as ``name=token`` pairs.

        Returns an empty string when the widget has not been rendered yet or
        nothing is selected. Multiple selections are joined with ``&``.
        Scalars of any type are tokenised via ``str``, matching how
        ``hash_table`` is built.
        """
        if self.key is None:
            return ""
        selected = st.session_state.get(self.key, None)
        if selected is None:
            return ""
        name = self.query.lower()
        values = selected if isinstance(selected, (list, tuple)) else [selected]
        return "&".join(f"{name}={self._token(v)}" for v in values)