# -*- coding: utf-8 -*-
"""
To run:
- activate the virtual environment
- streamlit run path/to/streamlit_app.py
"""
import logging
import os
import re
import sys
import time
import warnings
import shutil

from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
import openai
import pandas as pd
import streamlit as st
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
from streamlit_chat import message

from streamlit_langchain_chat.constants import *
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
from streamlit_langchain_chat.dataset import Dataset

# Configure logger
# force=True makes basicConfig remove any handlers already attached to the root
# logger before installing its own, so this call must stay ahead of addHandler below.
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
# Attach an extra handler that writes records to stdout (it has no formatter of its own).
# NOTE(review): together with the handler installed by basicConfig, each record is
# presumably emitted twice — confirm this is intended.
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Suppress every warning for the whole process.
warnings.filterwarnings('ignore')

# Seed the per-session state with its defaults the first time the script runs.
# Keys that already exist are left untouched.
for _state_key, _initial_value in (
    ('generated', []),
    ('past', []),
    ('costs', []),
    ('contexts', []),
    ('chunks', []),
    ('user_input', ""),
    ('dataset', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value


def check_api_keys() -> bool:
    """Tell whether the API keys needed to build an index are available.

    Returns:
        True when the ``OPENAI_API_KEY`` environment variable is non-empty and,
        if ``app.params['index_id']`` is 2 (the pinecone option in the sidebar),
        ``PINECONE_API_KEY`` is non-empty as well; False otherwise.
    """
    index_id = app.params['index_id']

    # os.getenv returns a str here (the default is ''), so truthiness is the
    # same test as "is a non-empty string".
    openai_key_ready = bool(os.getenv('OPENAI_API_KEY', ''))
    if index_id != 2:
        # The pinecone key is only required for the pinecone index.
        return openai_key_ready
    return openai_key_ready and bool(os.getenv('PINECONE_API_KEY', ''))


def check_combination_point() -> bool:
    """Tell whether step 2 of the sidebar (model / endpoint) is fully filled in.

    Returns:
        For ``type_id`` 1: True when the OpenAI key, the API base and the
        deployment id are all set. For ``type_id`` 2: True when the OpenAI key
        and the API base are set. False for any other ``type_id``.
    """
    type_id = app.params['type_id']
    api_key = os.getenv('OPENAI_API_KEY', '')
    has_api_key = isinstance(api_key, str) and len(api_key) > 0
    api_base = app.params['api_base']

    if type_id == 1:
        deployment_id = app.params['deployment_id']
        return bool(has_api_key and api_base and deployment_id)
    if type_id == 2:
        return bool(has_api_key and api_base)
    return False


def check_index() -> bool:
    """Tell whether the chat can run: an index exists or no source is required.

    Returns:
        True when the session dataset has a truthy ``index_docstore`` attribute
        or when ``app.params['source_id']`` is 4; False otherwise.
    """
    dataset = st.session_state['dataset']

    # A missing dataset (None) or one without the attribute counts as "no index".
    has_index = getattr(dataset, "index_docstore", False)
    source_not_needed = app.params['source_id'] == 4
    return bool(has_index or source_not_needed)


def check_index_point() -> bool:
    """Tell whether step 3 of the sidebar (index store) is fully filled in.

    Returns:
        True when an index is selected and, if that index is number 2, both
        ``PINECONE_API_KEY`` and ``PINECONE_ENVIRONMENT`` are set and non-empty.
    """
    index_id = app.params['index_id']

    if index_id == 2:
        api_key = os.getenv('PINECONE_API_KEY', '')
        key_ready = isinstance(api_key, str) and len(api_key) > 0
        environment_ready = os.getenv('PINECONE_ENVIRONMENT', False)
    else:
        # The other index stores need no extra credentials.
        key_ready = True
        environment_ready = True

    return bool(index_id and key_ready and environment_ready)


def check_params_point() -> bool:
    """Tell whether step 4 of the sidebar (configuration) holds usable values.

    Returns:
        True when ``max_sources`` is truthy and ``temperature`` is a float.
    """
    top_k = app.params['max_sources']
    temperature = app.params['temperature']
    return bool(top_k and isinstance(temperature, float))


def check_source_point() -> bool:
    """Report step 1 of the sidebar (source type) as ready; it always is."""
    return True


def clear_chat_history():
    """Reset every conversation list in the session state to an empty list.

    Nothing is written when all of the lists are already empty.
    """
    history_keys = ('past', 'generated', 'contexts', 'chunks', 'costs')
    if any(st.session_state[key] for key in history_keys):
        for key in history_keys:
            st.session_state[key] = []


def clear_index():
    """Delete the index directory from disk and forget the session dataset.

    With a dataset in the session, its ``index_path`` directory is removed (if
    present) and the session entry is reset to None. Without one, the
    ``TEMP_DIR / "default"`` directory is removed when it exists.
    """
    dataset = st.session_state['dataset']
    if dataset:
        # Remove the directory together with the files inside it.
        index_dir = dataset.index_path
        if index_dir.exists():
            shutil.rmtree(str(index_dir))

        st.session_state['dataset'] = None
        return

    default_dir = TEMP_DIR / "default"
    if default_dir.exists():
        shutil.rmtree(str(default_dir))


def check_sources() -> bool:
    """Tell whether step 5 of the sidebar provides something to index.

    Returns:
        True when exactly one kind of source is present (selected local files or
        non-empty urls, but not both), or when ``source_id`` is 4.
    """
    selected_rows = app.params['uploaded_files_rows']
    urls_df = app.params['urls_df']
    source_id = app.params['source_id']

    # With no upload the grid holds a single placeholder row whose filepath is "".
    has_files = bool(selected_rows and selected_rows[-1].get('filepath') != "")
    has_urls = any(url for url, _citation in urls_df.to_numpy())

    # "files only" or "urls only" is the same as the two flags being different.
    exactly_one_kind = has_files != has_urls
    return exactly_one_kind or (source_id == 4)


def collect_dataset_and_built_index():
    """Gather the selected sources into a ``Dataset``, build its index and keep it in the session.

    Reads the sidebar choices from ``app.params``, configures the ``openai``
    module globals for the selected endpoint type, adds every url and every
    selected uploaded file to a new ``Dataset`` and finally builds either the
    faiss index (``index_id == 1``) or the pinecone index (any other value).
    The resulting dataset is stored in ``st.session_state['dataset']``.

    Urls that cannot be added are logged and skipped; a failure while adding an
    uploaded file is not caught and propagates to the caller.
    """
    start = time.time()
    uploaded_files_rows = app.params['uploaded_files_rows']
    urls_df = app.params['urls_df']
    type_id = app.params['type_id']
    temperature = app.params['temperature']
    index_id = app.params['index_id']
    api_base = app.params['api_base']
    deployment_id = app.params['deployment_id']

    # With no upload the grid holds a single placeholder row whose filepath is "".
    some_files = bool(uploaded_files_rows) and uploaded_files_rows[-1].get('filepath') != ""
    some_urls = any(url for url, _citation in urls_df.to_numpy())

    # type_id 1 is the Azure endpoint, anything else the plain OpenAI endpoint.
    openai.api_type = "azure" if type_id == 1 else "open_ai"
    openai.api_base = api_base
    openai.api_version = "2023-03-15-preview" if type_id == 1 else None

    # "text-davinci-003" goes through the completion wrapper, every other
    # deployment through the chat wrapper. The completion branch used to look up
    # an undefined `combination_id` and raised NameError; the deployment id
    # entered in the sidebar is passed instead.
    llm_class = OpenAI if deployment_id == "text-davinci-003" else ChatOpenAI
    dataset = Dataset(
        llm=llm_class(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    )

    # get url documents
    if some_urls:
        urls_df = urls_df.reset_index()
        for _url_index, url_row in urls_df.iterrows():
            url = url_row.get('urls', '')
            citation = url_row.get('citation string', '')
            if not url:
                continue
            try:
                dataset.add(
                    url,
                    citation,
                    citation,
                    disable_check=True  # True to accept Japanese letters
                )
            except Exception:
                # Best effort: one unreachable/unsupported url must not abort
                # the whole build, but the failure is recorded with its traceback.
                logging.exception("Could not add url %r to the dataset", url)

    # get local documents
    if some_files:
        for uploaded_files_row in uploaded_files_rows:
            citation = uploaded_files_row.get('citation string')
            # The citation doubles as key unless it is missing or contains a comma.
            key = None if citation is None or ',' in citation else citation
            dataset.add(
                uploaded_files_row.get('filepath'),
                citation,
                key=key,
                disable_check=True  # True to accept Japanese letters
            )

    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    if index_id == 1:
        dataset._build_faiss_index(openai_embeddings)
    else:
        dataset._build_pinecone_index(openai_embeddings)
    st.session_state['dataset'] = dataset

    if OPERATING_MODE == "debug":
        print(f"time to collect dataset: {time.time() - start:.2f} [s]")


def configure_streamlit_and_page():
    """Apply the page configuration from ``ST_CONFIG`` and inject the column CSS."""
    st.set_page_config(**ST_CONFIG)

    # Style override that keeps columns side by side (half width each), also on mobile.
    responsive_columns_css = """<style>
        [data-testid="column"] {
            width: calc(50% - 1rem);
            flex: 1 1 calc(50% - 1rem);
            min-width: calc(50% - 1rem);
        }
        </style>"""
    st.write(responsive_columns_css, unsafe_allow_html=True)


def get_answer():
    """Run the current user input against the session dataset.

    Returns:
        Whatever ``dataset.query`` returns, or None when the input is empty, no
        dataset is stored in the session, or no model type / index is selected.
    """
    query = st.session_state['user_input']
    dataset = st.session_state['dataset']
    type_id = app.params['type_id']
    index_id = app.params['index_id']
    # NOTE(review): read here but never forwarded to dataset.query — confirm
    # whether the top-k setting is meant to be applied.
    max_sources = app.params['max_sources']

    if not (query and dataset and type_id and index_id):
        return None

    # Pair each earlier question with the answer that was generated for it.
    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))
    use_marginal_relevance = index_id == 1
    started_at = time.time()
    embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    result = dataset.query(
        query,
        embeddings,
        chat_history,
        marginal_relevance=use_marginal_relevance,  # if pinecone is used it must be False
    )
    if OPERATING_MODE == "debug":
        elapsed = time.time() - started_at
        print(f"time to get answer: {elapsed:.2f} [s]")
        print("-" * 10)
    return result


def load_main_page():
    """
    Render the body of the page: title, status line, action buttons, the
    conversation so far and the input box.
    """
    # Markdown heading levels: "#" ~ st.title (<h1>), "##" ~ st.header (<h2>),
    # "###" ~ st.subheader (<h3>).
    st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
    validate_status()
    st.markdown(f"#### **Status**: {app.params['status']}")

    # Invisible anchor targeted by the "Link to top" link further down.
    st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
    index_col, chat_col, link_col = st.columns(3)
    index_col.button(label="clear index", type="primary", on_click=clear_index)
    chat_col.button(label="clear conversation", type="primary", on_click=clear_chat_history)
    link_col.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)

    # One question/answer pair per turn, each followed by its three detail expanders.
    state = st.session_state
    detail_sections = (
        ("See context", 'contexts'),
        ("See chunks", 'chunks'),
        ("See costs", 'costs'),
    )
    for turn, answer in enumerate(state["generated"]):
        message(state['past'][turn], is_user=True, key=str(turn) + '_user')
        message(answer, key=str(turn))
        for title, state_key in detail_sections:
            with st.expander(title):
                st.write(state[state_key][turn])

    # The input box is usable once an index exists or when no source is needed.
    dataset = state['dataset']
    has_index = getattr(dataset, "index_docstore", False)
    source_not_needed = app.params['source_id'] == 4
    chat_enabled = has_index or source_not_needed
    st.text_input(
        "You:",
        key='user_input',
        on_change=on_enter,
        disabled=not chat_enabled,
    )

    st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
    # Invisible anchor targeted by the "Link to bottom" link above.
    st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)


def load_sidebar_page() -> None:
    """Render the sidebar and record every choice made there in ``app.params``.

    The sidebar walks the user through numbered steps: source type, model
    endpoint, index store, configuration values, the sources themselves and
    finally a "Build index" button. The following ``app.params`` keys are
    written: ``source_id``, ``type_id``, ``api_base``, ``deployment_id``,
    ``index_id``, ``max_sources``, ``temperature``, ``uploaded_files_rows`` and
    ``urls_df``. API keys and the pinecone environment typed into the sidebar
    are written to ``os.environ``.

    Widgets are created in the order they appear here, so statement order
    defines the layout of the sidebar.
    """
    st.sidebar.markdown("## Instructions")

    # ############ #
    # SOURCES TYPE #
    # ############ #
    st.sidebar.markdown("1. Select a source:")
    source_selected = st.sidebar.selectbox(
        "Choose the location of your info to give context to chatgpt",
        [key for key, value in SOURCES_IDS.items()])
    # Map the selected label back to its numeric id (None if the label is unknown).
    app.params['source_id'] = SOURCES_IDS.get(source_selected, None)

    # ##### #
    # MODEL #
    # ##### #
    st.sidebar.markdown("2. Select a model (LLM):")
    combination_selected = st.sidebar.selectbox(
        "Choose type: MSF Azure OpenAI and model / OpenAI",
        [key for key, value in TYPE_IDS.items()])
    app.params['type_id'] = TYPE_IDS.get(combination_selected, None)

    # NOTE(review): 'api_base' and 'deployment_id' are only assigned for type_id
    # 1 and 2; other code reads them unconditionally — confirm TYPE_IDS has no
    # other values.
    if app.params['type_id'] == 1:  # with AzureOpenAI endpoint
        # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
        # NOTE(review): os.environ is process-wide, so the key is presumably
        # shared by every session served by this process — confirm.
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter Azure OpenAI API Key",
            type="password"
        ).strip()
        app.params['api_base'] = st.sidebar.text_input(
            label="Enter Azure API base",
            placeholder="https://<api_base_endpoint>.openai.azure.com/",
        ).strip()
        app.params['deployment_id'] = st.sidebar.text_input(
            label="Enter Azure deployment_id",
        ).strip()
    elif app.params['type_id'] == 2:  # with OpenAI endpoint
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter OpenAI API Key",
            placeholder="sk-...",
            type="password"
        ).strip()
        # The OpenAI endpoint is fixed and needs no deployment id.
        app.params['api_base'] = "https://api.openai.com/v1"
        app.params['deployment_id'] = None

    # ####### #
    # INDEXES #
    # ####### #
    st.sidebar.markdown("3. Select a index store:")
    index_selected = st.sidebar.selectbox(
        "Type of Index",
        [key for key, value in INDEX_IDS.items()])
    app.params['index_id'] = INDEX_IDS.get(index_selected, None)
    if app.params['index_id'] == 2:  # with pinecone
        os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
            label="Enter pinecone API Key",
            type="password"
        ).strip()

        os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
            label="Enter pinecone environment",
            placeholder="eu-west1-gcp",
        ).strip()

    # ############## #
    # CONFIGURATIONS #
    # ############## #
    st.sidebar.markdown("4. Choose configuration:")
    # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
    # NOTE(review): the label advertises 1-5 but no min_value/max_value is
    # passed, so the range is not enforced by the widget — confirm.
    max_sources = st.sidebar.number_input(
        label="Top-k: Number of chunks/sections (1-5)",
        step=1,
        format="%d",
        value=5
    )
    app.params['max_sources'] = max_sources
    temperature = st.sidebar.number_input(
        label="Temperature (0.0 – 1.0)",
        step=0.1,
        format="%f",
        value=0.0,
        min_value=0.0,
        max_value=1.0
    )
    # Keep a single decimal, matching the 0.1 step of the widget.
    app.params['temperature'] = round(temperature, 1)

    # ############## #
    # UPLOAD SOURCES #
    # ############## #
    # Default for every source type; replaced below when local files are used.
    app.params['uploaded_files_rows'] = []
    if app.params['source_id'] == 1:
        # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
        # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
        st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
        uploaded_files = st.sidebar.file_uploader(
            "Choose files",
            accept_multiple_files=True,
            type=['pdf', 'PDF',
                  'txt', 'TXT',
                  'html',
                  'docx', 'DOCX',
                  'pptx', 'PPTX',
                  ],
        )
        # Save the uploads to disk; yields [filepath, citation string] rows, or
        # a single blank row when nothing has been uploaded.
        uploaded_files_dataset = request_pathname(uploaded_files)
        uploaded_files_df = pd.DataFrame(
            uploaded_files_dataset,
            columns=['filepath', 'citation string'])
        uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
        # Pre-select every row, unless the last row is the blank placeholder.
        uploaded_files_grid_options_builder.configure_selection(
            selection_mode='multiple',
            pre_selected_rows=list(range(uploaded_files_df.shape[0])) if uploaded_files_df.iloc[-1, 0] != "" else [],
            use_checkbox=True,
        )
        uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
        uploaded_files_grid_options_builder.configure_auto_height()
        uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
        with st.sidebar:
            uploaded_files_ag_grid = AgGrid(
                uploaded_files_df,
                gridOptions=uploaded_files_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        # Only the rows ticked in the grid are kept as sources.
        app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]

    # Default for every source type; replaced below when urls are used.
    app.params['urls_df'] = pd.DataFrame()
    if app.params['source_id'] == 3:
        st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
        # option 1: with streamlit version 1.20.0+
        # app.params['urls_df'] = st.sidebar.experimental_data_editor(
        #     pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
        #     use_container_width=True,
        #     num_rows="dynamic",
        # )

        # option 2: with streamlit version 1.19.0
        # Five blank, editable rows for the user to fill in.
        urls_dataset = [["", ""],
                        ["", ""],
                        ["", ""],
                        ["", ""],
                        ["", ""]]
        urls_df = pd.DataFrame(
            urls_dataset,
            columns=['urls', 'citation string'])

        urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
        urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
        urls_grid_options_builder.configure_auto_height()
        urls_grid_options = urls_grid_options_builder.build()
        with st.sidebar:
            urls_ag_grid = AgGrid(
                urls_df,
                gridOptions=urls_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        df = urls_ag_grid.data
        # Drop the rows whose url was left blank.
        df = df[df.urls != ""]
        app.params['urls_df'] = df

    if app.params['source_id'] in (1, 2, 3):
        st.sidebar.markdown("""6. Build an index where you can ask""")
        # The button stays disabled until the keys and at least one source are ready.
        api_keys_ready = check_api_keys()
        source_ready = check_sources()
        enable_index_button = api_keys_ready and source_ready
        if st.sidebar.button("Build index", disabled=not enable_index_button):
            collect_dataset_and_built_index()


def main() -> None:
    """Build the page: page configuration first, then the sidebar, then the body.

    The sidebar runs before the body because the body reads the values the
    sidebar stores in ``app.params``.
    """
    configure_streamlit_and_page()
    load_sidebar_page()
    load_main_page()


def on_enter():
    """Callback of the chat input: answer the question and record the turn.

    When ``get_answer()`` yields a result, the question, answer, context,
    chunks and cost string are appended to their session lists and the input
    box is cleared. Otherwise the session state is left untouched.
    """
    result = get_answer()
    if not result:
        return

    state = st.session_state
    state.past.append(state['user_input'])
    state.generated.append(result.answer)
    state.contexts.append(result.context)
    state.chunks.append(result.chunks)
    state.costs.append(result.cost_str)
    state['user_input'] = ""


def request_pathname(files):
    """Save uploaded files under ``TEMP_DIR`` and return one grid row per file.

    Args:
        files: Uploaded file objects exposing ``name`` and ``getbuffer()``; may
            be empty or None.

    Returns:
        A list of ``[filepath, citation string]`` rows. ``filepath`` is the
        saved location expressed relative to ``ROOT_DIR`` and the citation
        string is the name the file was uploaded with. When ``files`` is empty
        a single blank row ``[["", ""]]`` is returned.

    Raises:
        ValueError: if ``TEMP_DIR`` is not located under ``ROOT_DIR``.
    """
    if not files:
        return [["", ""]]

    # exist_ok=True already makes this a no-op when the directory is there, so
    # no separate existence check (and no `Path` name) is needed.
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    rows = []
    for file in files:
        # The name is supplied by the client: keep only its last path component
        # so the file cannot be written outside TEMP_DIR.
        safe_name = os.path.basename(file.name)
        # Relative (not absolute) path, resolved against the working directory.
        file_path = str((TEMP_DIR / safe_name).relative_to(ROOT_DIR))
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())
        rows.append([file_path, file.name])
    return rows


def validate_status():
    """Evaluate every readiness check and store the status text in ``app.params['status']``.

    All six checks are always executed, in sidebar-step order. The status is
    the hint of the first check that fails, or the "ready" text when none does.
    """
    # (result, hint shown when the check fails), in the order the steps appear
    # on the sidebar. The first failing entry wins.
    checks = (
        (check_source_point(), "⚠️Review step 1 on the sidebar."),
        (check_combination_point(), "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."),
        (check_index_point(), "⚠️Review step 3 on the sidebar. Index API Key or environment."),
        (check_params_point(), "⚠️Review step 4 on the sidebar"),
        (check_sources(), "⚠️Review step 5 on the sidebar. Waiting for some source..."),
        (check_index(), "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."),
    )

    # The former catch-all "something is not ready" branch could never run:
    # if not every check passes, one of them is by definition the first failure.
    app.params['status'] = next(
        (hint for passed, hint in checks if not passed),
        "✨Ready✨",
    )


class StreamlitLangchainChatApp:
    """Application object whose ``params`` dictionary is shared by the page-building functions."""

    def __init__(self) -> None:
        """Start with an empty ``params`` dictionary; takes no arguments."""
        self.params = {}

    def run(self, **state) -> None:
        """Render the application by calling ``main()``; ``state`` is accepted but not used."""
        main()


if __name__ == "__main__":
    # `app` has to be bound as a module-level name before run() is called: the
    # functions in this module read `app.params` as a global.
    app = StreamlitLangchainChatApp()
    app.run()