import torch
import src.constants.config as configurations
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from src.constants.credentials import cohere_trial_key, mixedbread_key
import streamlit as st
from src.reader import Reader
from src.utils_search import UtilsSearch
from copy import deepcopy
import numpy as np
import cohere
from mixedbread_ai.client import MixedbreadAI
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset


# Select the service configuration used by this demo. This rebinds the imported
# config module alias to the configuration mapping itself; later code indexes it
# like a dict.
configurations = configurations.service_mxbai_msc_direct_config
semantic_column_names = configurations["semantic_column_names"]
cross_encoder_name = configurations["cross_encoder_name"]

# API clients: `co` is the Cohere client used for reranking; `model` is the
# Mixedbread client, which is passed to retrieval and also used for reranking.
api_key = cohere_trial_key
co = cohere.Client(api_key)
model = MixedbreadAI(api_key=mixedbread_key)

@st.cache_data
def init():
    """Load the dataset and build its search index.

    The dataframe is read through ``Reader`` (configured with
    ``configurations["reader_config"]``) and turned into an index by
    ``UtilsSearch.dataframe_to_index``.

    Returns:
        tuple: ``(df, index, search_utils)`` -- the loaded dataframe, the
        index built from it, and the ``UtilsSearch`` helper that built it.
    """
    # NOTE(review): st.cache_data serialises (pickles) return values and hands
    # each caller a copy; st.cache_resource may suit the index/search_utils
    # objects better -- confirm nothing mutates `df` before switching.
    search_utils = UtilsSearch(configurations)
    reader = Reader(config=configurations["reader_config"])
    frame = reader.read()
    return frame, search_utils.dataframe_to_index(frame), search_utils

def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for a discrete filter column.

    The values come from ``search_utils.top_10_common_values(df, column_name)``.
    They are computed on first use and memoised in ``st.session_state`` under
    ``column_name``, so later reruns in the same session reuse them.
    """
    cache = st.session_state
    if column_name in cache:
        return cache[column_name]
    cache[column_name] = search_utils.top_10_common_values(df, column_name)
    return cache[column_name]


# Run the expensive load/index step once per browser session; Streamlit re-executes
# this script on every interaction, so later runs reuse the stored tuple.
if "init_results" not in st.session_state:
    st.session_state["init_results"] = init()

# Unpack the (dataframe, index, search helper) triple produced by init().
df, index, search_utils = st.session_state["init_results"]

# Page header.
st.title('Search Demo')

# Free-text query plus a toggle choosing the reranker (Cohere when checked,
# Mixedbread otherwise); the toggle starts checked.
query = st.text_input(label='Enter your search query here')
use_cohere = st.checkbox(label='Use Cohere', value=True)

# Work on a deep copy so the per-run user selections never leak back into the
# shared configuration mapping.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])

# Filter settings chosen by the user on this run; they replace the configured
# column lists in `programmatic_search_config` at the end of this section.
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": []
}


def _options_including_defaults(possible_values, default_values):
    """Return `possible_values` as a list, extended with any defaults it lacks.

    st.multiselect raises if a default is not among its options. The options
    here are only the column's most common values, so configured defaults
    outside that set must be appended explicitly.
    """
    options = list(possible_values)
    if default_values is None:
        return options
    if not isinstance(default_values, (list, tuple, set, frozenset)):
        # A single scalar/string default is treated as a one-element list.
        default_values = [default_values]
    for value in default_values:
        if value not in options:
            options.append(value)
    return options


for column in programmatic_search_config['scalar_columns']:
    # Range filter: one number input per bound, limited to the configured range.
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    label = col_name.capitalize()
    user_min = st.number_input(f'Minimum {label}', min_value=min_val, max_value=max_val, value=min_val)
    user_max = st.number_input(f'Maximum {label}', min_value=min_val, max_value=max_val, value=max_val)
    if user_min > user_max:
        # An inverted range can only ever match nothing; tell the user why.
        st.warning(f'Minimum {label} is greater than Maximum {label}; no rows can match this range.')
    dynamic_programmatic_search_config['scalar_columns'].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

for column in programmatic_search_config['discrete_columns']:
    # Category filter: multiselect over the column's common values, pre-selected
    # with the configured defaults.
    col_name = column["column_name"]
    default_values = column["default_values"]
    possible_values = get_possible_values_for_column(col_name, search_utils, df)
    options = _options_including_defaults(possible_values, default_values)
    selected_values = st.multiselect(f'Select {col_name.capitalize()}', options=options, default=default_values)
    dynamic_programmatic_search_config['discrete_columns'].append(
        {"column_name": col_name, "default_values": selected_values}
    )

programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']


def _rank_with_mixedbread(query, candidates):
    """Rerank `candidates` with the Mixedbread cross-encoder.

    Returns a copy of the top (up to 10) rows of `candidates`, in reranked order.
    `candidates` is expected to have a default RangeIndex (it is reset by the caller).
    """
    records = candidates.to_dict(orient='records')
    dataset_str = SchemaStringDataset(records, configurations)
    # Each document is the dataset's "inputs" string, truncated to 256 characters.
    documents = [batch["inputs"][:256] for batch in dataset_str]
    res = model.reranking(
        model=cross_encoder_name,
        query=query,
        input=documents,
        top_k=10,
        return_input=False
    )
    # item.index is the position of the document in `documents`, i.e. the row
    # position in `candidates`, so select positionally.
    top_ids = [item.index for item in res.data]
    return candidates.iloc[top_ids].copy()


def _rank_with_cohere(query, candidates):
    """Rerank `candidates` with Cohere's rerank-english-v3.0 model.

    Only `semantic_column_names` are sent to the reranker, stringified, with
    missing values as "". Returns a copy of the top (up to 10) rows of
    `candidates`, in reranked order.
    """
    # Fill NaNs on a column-subset copy only: the rows shown to the user keep
    # their original values and dtypes.
    text_frame = candidates[semantic_column_names].fillna(value="")
    docs = [
        {name: str(row[name]) for name in semantic_column_names}
        for row in text_frame.to_dict('records')
    ]
    rank_fields = list(docs[0].keys())
    results = co.rerank(query=query, documents=docs, top_n=10, model='rerank-english-v3.0',
                        rank_fields=rank_fields)
    top_ids = [hit.index for hit in results.results]
    return candidates.iloc[top_ids].copy()


# Search button
if st.button('Search'):
    if not query:
        st.write("Please enter a query to search.")
    else:
        # Retrieve up to 1000 candidates, then apply the user's programmatic filters.
        df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): ascending=True keeps the 100 rows with the *smallest*
        # 'similarities' values. That is right only if retrieve() stores a
        # distance-like score (lower = closer) under this name; if it is a true
        # similarity this should be ascending=False -- confirm against UtilsSearch.
        df_filtered = (
            df_filtered.sort_values(by='similarities', ascending=True)
            .head(100)
            .reset_index(drop=True)
        )

        if df_filtered.empty:
            st.write('No results found')
        else:
            rerank = _rank_with_cohere if use_cohere else _rank_with_mixedbread
            results_df = rerank(query, df_filtered)
            # 1-based position in the reranked order, added for both rerankers.
            results_df['rank'] = np.arange(len(results_df)) + 1
            results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            st.write(results_df)