import streamlit as st
import json
import base64
import urllib

# Shared stack of key segments for the widget currently being constructed.
# NumberInput / ComplexNumberInput push their `key` on entry and pop it on exit;
# '.'.join(PATH) gives the dotted prefix for their session-state keys.
PATH = []


def url_from_state_dict(state):
    """Encode a state mapping as query parameters.

    Every value is serialised with ``json.dumps`` and wrapped in a one-item
    list, which is the list-per-key shape ``state_dict_from_url`` reads back.
    Keys appear in sorted order.
    """
    params = {}
    for name in sorted(state):
        params[name] = [json.dumps(state[name])]
    return params

def state_dict_from_url(params):
    """Decode query parameters produced by ``url_from_state_dict``.

    Args:
        params: Mapping of parameter name to a list of raw string values, as
            returned by ``st.experimental_get_query_params()``.

    Returns:
        Dict (keys in sorted order) mapping each name to the JSON-decoded value
        of its *last* occurrence. Names with no values, or whose last value is
        not valid JSON, are omitted: the query string is user-editable and a
        hand-edited URL must not crash the app.
    """
    state = {}
    for key, values in sorted(params.items()):
        if not values:
            continue
        try:
            state[key] = json.loads(values[-1])
        except json.JSONDecodeError:
            # Malformed / hand-edited value: ignore this key instead of crashing.
            continue
    return state


# Give vertical blocks that sit inside another vertical block a grey, outlined
# "card" look. NOTE(review): relies on Streamlit's internal data-testid names,
# which may change between Streamlit versions.
_NESTED_BLOCK_CSS = '''
<style>
    [data-testid="stVerticalBlock"] [data-testid="stVerticalBlock"] {
        background-color: #d4d4d4;
        box-shadow: 0px 0px 0px 6px #d4d4d4, 0px 0px 0px 7px #000000e3;
        border-radius: 2px;
    }
</style>
'''
st.markdown(_NESTED_BLOCK_CSS, unsafe_allow_html=True)

# Debug panels: the live session state beside what the URL currently encodes.
left, right = st.columns(2)

with left:
    st.write('### Session State:')
    st.write(st.session_state)

with right:
    st.write('### URL State:')
    st.write(state_dict_from_url(st.experimental_get_query_params()))

# On a fresh session (session state still empty), copy the state encoded in the
# query string into session state so the widgets created below pick it up,
# then rerun the script.
# Only do this when the URL actually carries state: with an empty query string
# session state would stay empty, and an unconditional rerun would repeat on
# every run forever.
_url_state = state_dict_from_url(st.experimental_get_query_params())
if _url_state and not st.session_state.to_dict():
    st.write('### Auto Setting Session State:')
    for key, value in _url_state.items():
        st.session_state[key] = value
    st.experimental_rerun()

autoset = st.sidebar.radio('Auto Set URL', ['on', 'off'], key='radio')


class NumberInput:
    """A numeric widget whose value lives in ``st.session_state``.

    The state key is the dotted ``PATH`` prefix plus this widget's ``key``,
    followed directly by the field name (no separator). For example, a top-level
    ``NumberInput(..., key='number')`` stores its value under ``'numbervalue'``.
    Existing URLs depend on that exact key format.
    """

    def __init__(
        self,
        label: str,
        value: float,
        key: str,
        min_value: float = None,
        max_value: float = None,
    ):
        PATH.append(key)
        try:
            self.label = label
            # Snapshot the full dotted path now; PATH is shared and mutated.
            self.path = '.'.join(PATH)
            self.default = value
            self.min_value = min_value
            self.max_value = max_value

            self.define_state('value')
        finally:
            # Always pop, even if seeding state fails, so later widgets never
            # inherit a stale path segment from this one.
            PATH.pop()

    def _state_key(self, key):
        """Return the session-state key for field *key* of this widget."""
        return self.path + key

    def define_state(self, key):
        """Seed field *key* with the default unless session state already holds a non-None value."""
        state_key = self._state_key(key)
        if st.session_state.get(state_key) is None:
            st.session_state[state_key] = self.default

    def read_state(self, key):
        """Return the current session-state value of field *key*."""
        return st.session_state[self._state_key(key)]

    def read_value(self):
        """Return the widget's current number."""
        return self.read_state('value')

    @property
    def value(self):
        return self.read_value()

    def render(self):
        """Draw the number input bound to this widget's session-state key.

        No ``value=`` is passed: the initial value comes from the session
        state seeded in ``define_state``.
        """
        return st.number_input(
            label=self.label,
            min_value=self.min_value,
            max_value=self.max_value,
            key=self._state_key('value'),
            step=0.5,
        )


def argand(a):
    """Plot complex numbers as vectors from the origin (an Argand diagram).

    Args:
        a: Sequence of complex (or real) numbers.

    Returns:
        A matplotlib ``Figure`` with one red line-and-marker segment from 0 to
        each number, and both axes spanning ``[-limit, limit]``. ``limit`` is
        the largest magnitude rounded up to an integer, and at least 1.
    """
    import numpy as np
    from matplotlib.figure import Figure

    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure in a global registry, this runs on every Streamlit rerun, and
    # nothing ever closes those figures.
    fig = Figure()
    ax = fig.subplots()

    values = np.asarray(a, dtype=complex)
    for z in values:
        ax.plot([0, z.real], [0, z.imag], 'ro-', label='python')

    # Clamp to >= 1 so an empty input, or a plot of 0 alone, still gets a
    # non-degenerate window instead of identical low/high limits.
    limit = max(float(np.max(np.ceil(np.abs(values)), initial=0.0)), 1.0)
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)

    return fig


class ComplexNumberInput:
    """Two NumberInput widgets (real and imaginary parts) plus an Argand plot.

    Child widgets are created while ``key`` is on ``PATH``, so their state keys
    are ``'<prefix><key>.realvalue'`` and ``'<prefix><key>.imagvalue'``.
    """

    def __init__(
        self,
        label: str,
        value: complex,
        key: str,
        swapplaces: bool = False,
    ) -> None:
        PATH.append(key)
        try:
            self.label = label
            self.path = '.'.join(PATH)
            self.default = value
            # Children must be constructed while `key` is still on PATH.
            self.real = NumberInput('Real', self.default.real, 'real')
            self.imag = NumberInput('Imaginary', self.default.imag, 'imag')
            self.swapplaces = swapplaces
        finally:
            # Always pop, even if a child fails to build, so PATH stays balanced.
            PATH.pop()

    def read_value(self):
        """Return the current value as a Python ``complex``."""
        return complex(self.real.value, self.imag.value)

    @property
    def value(self):
        return self.read_value()

    def render(self):
        """Draw a header, both part inputs, and the Argand plot in one container.

        When ``swapplaces`` is true the imaginary input is drawn first.
        """
        with st.container():
            st.header(self.label)
            parts = (self.imag, self.real) if self.swapplaces else (self.real, self.imag)
            for part in parts:
                part.render()
            st.pyplot(argand([self.value]))

# Demo link that restores a sample state. The href is relative (query string
# only), so it resolves against wherever the app is actually served instead of
# assuming http://localhost:8501. Values are JSON-encoded, matching
# url_from_state_dict (hence radio=%22on%22, i.e. the JSON string "on").
st.markdown('''
<a target="_self" href="?complex1.imagvalue=2.0&complex1.realvalue=3.5&complex2.imagvalue=2.5&complex2.realvalue=6.0&numbervalue=0.05&radio=%22on%22">LINK</a>
''', unsafe_allow_html=True)

# Standalone number widget; its value is kept under session-state key 'numbervalue'.
n = NumberInput(label='Number', value=123, key='number')
n.render()
st.write(n.value)


# Two complex inputs side by side, followed by their sum.
left, right = st.columns(2)

with left:
    x = ComplexNumberInput(label='Complex', value=1 + 2j, key='complex1')
    x.render()
    st.write('`x` = ', x.value)

with right:
    # The right-hand widget takes its title and part order from widgets drawn
    # in the same column (text input first, then checkbox).
    y_label = st.text_input('header')
    y_swap = st.checkbox('swap places')
    y = ComplexNumberInput(label=y_label, value=1 + 2j, key='complex2', swapplaces=y_swap)
    y.render()
    st.write('`y` = ', y.value)

st.write('`x` + `y` = ', x.value + y.value)

# When enabled in the sidebar, write the whole session state into the query
# string (one JSON-encoded value per key) so the page URL captures it.
if autoset == 'on':
    query_params = url_from_state_dict(st.session_state.to_dict())
    st.experimental_set_query_params(**query_params)