from typing import Dict, List

import torch
import pandas as pd
import streamlit as st
from findkit import retrieval_pipeline

import config
from search_utils import (
    RetrievalPipelineWrapper,
    get_doc_cols,
    get_repos_with_descriptions,
    get_retrieval_df,
    merge_cols,
)


class RetrievalApp:
    """Streamlit front-end for searching repositories with a query/document encoder pair.

    The constructor renders the sidebar configuration widgets (device, model,
    text features) and loads the retrieval dataframe. ``main`` builds the
    retrieval pipeline and renders the search page.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Render the sidebar configuration widgets and load the data.

        Args:
            data_path: dataset identifier forwarded to
                ``search_utils.get_retrieval_df``.
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")

        # The cached dataframe is returned with allow_output_mutation=True,
        # so take a private copy before anything downstream can mutate it.
        self.retrieval_df = self.get_retrieval_df().copy()

        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name

        doc_cols = get_doc_cols(model_name)

        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    def is_cuda_available(self):
        """Return True if CUDA can actually be initialised in this process.

        Attempts a real CUDA initialisation instead of only querying
        availability, so a broken CUDA setup is reported as unavailable.
        """
        try:
            torch._C._cuda_init()
        except Exception:
            # Any initialisation failure (e.g. CPU-only build, missing driver
            # or device) means CPU only. Unlike the former bare ``except:``,
            # KeyboardInterrupt/SystemExit still propagate.
            return False
        return True

    def get_device_options(self):
        """Return the selectable torch devices, listing "cuda" first when usable."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    # NOTE(review): st.cache hashes every argument, including ``self``; it is
    # also deprecated in recent Streamlit releases in favour of
    # st.cache_data / st.cache_resource. Confirm before upgrading Streamlit.
    @st.cache(allow_output_mutation=True)
    def get_retrieval_df(self):
        """Load the retrieval dataframe for ``self.data_path`` (cached by Streamlit)."""
        # Inside a method body this name resolves to the module-level
        # ``search_utils.get_retrieval_df`` import, not to this method.
        return get_retrieval_df(self.data_path, config.text_list_cols)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        additional_shown_cols: List[str],
    ):
        """Render gold-standard (if known) and retrieved results for ``query``.

        Args:
            retrieval_pipe: pipeline used to run the search.
            query: user query text.
            k: number of results to retrieve.
            all_queries: queries that have gold-standard results.
            description_length: forwarded to ``retrieval_pipe.search``.
            repos_by_query: dataframe grouped by task; ``get_group(query)``
                yields the gold-standard repositories for a known query.
            additional_shown_cols: extra text columns to display.
        """
        if not query.strip():
            # Encoding an empty query would only return arbitrary results.
            st.info("enter a query to see results")
            return
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            results_df = retrieval_pipe.search(
                query, k, description_length, additional_shown_cols
            )
            # escape=False: the results table is rendered as raw HTML.
            st.write(
                results_df.to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Render the search page: sidebar options, test-set queries and results.

        Args:
            retrieval_pipeline: pipeline returned by ``get_retrieval_pipeline``.
        """
        retrieved_results = st.sidebar.number_input(
            "number of results", min_value=1, value=10
        )
        description_length = st.sidebar.number_input(
            "number of used description words", min_value=0, value=10
        )

        # One row per distinct task, with the number of rows tagged with it.
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        # Rename positionally: the column names produced by
        # value_counts().reset_index() changed in pandas 2.0.
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped)
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    # NOTE(review): cached by st.cache on the hash of ``self`` and
    # ``displayed_retrieval_df``; confirm both hash stably across reruns.
    @st.cache(allow_output_mutation=True)
    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build a retrieval pipeline over ``displayed_retrieval_df["document"]``.

        Args:
            displayed_retrieval_df: dataframe with a ``document`` column holding
                the text to encode.

        Returns:
            The ``RetrievalPipelineWrapper`` built from the selected encoders
            on ``self.device``.
        """
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Merge the selected text columns, build the pipeline and run the app."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)