from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset
import os
import pandas as pd
import numpy as np
import json
import faiss
import torch


class UtilsSearch:
    """Dense retrieval, cross-encoder re-ranking and DataFrame filtering helpers.

    ``config`` is stored on the instance: ``retrieve`` reads
    ``config["sentence_transformer_name"]`` in API mode, and ``rerank`` passes
    the whole dict to ``SchemaStringDataset``.
    """

    # Instruction prepended to every query before it is embedded.
    # NOTE(review): this is the BGE-style query instruction; presumably the
    # embedding model expects it as a *prefix* — confirm against the model card.
    QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "

    def __init__(self, config):
        self.config = config

    @staticmethod
    def dataframe_to_index(df):
        """Build an exact L2 FAISS index over the L2-normalised ``df['embeddings']``.

        Every stored vector has unit length, so the squared L2 distance returned
        by the index equals ``2 - 2 * cosine_similarity``: smaller means closer.

        :param df: DataFrame whose ``'embeddings'`` column holds one equal-length
            1-D vector per row.
        :return: A ``faiss.IndexFlatL2`` holding one vector per row of ``df``,
            in row order.
        """
        embeddings = np.stack(df['embeddings'].to_numpy()).astype(np.float32)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # An all-zero embedding would otherwise become a row of NaNs.
        norms[norms == 0] = 1.0
        norm_embeddings = np.ascontiguousarray(embeddings / norms)

        index = faiss.IndexFlatL2(norm_embeddings.shape[1])
        index.add(norm_embeddings)
        return index

    def retrieve(self, query, df, model, index, top_k=100, api=False):
        """Search ``index`` for ``query`` and return the matching rows of ``df``.

        :param query: The search query string. ``QUERY_INSTRUCTION`` is
            prepended to it before embedding.
        :param df: The DataFrame the index was built from; row ``i`` of ``df``
            must correspond to vector ``i`` of ``index``.
        :param model: When ``api`` is false, a local encoder exposing ``to`` and
            ``encode``; when ``api`` is true, a client exposing
            ``embeddings(input=..., model=..., prompt=...)``.
        :param index: A FAISS-style index whose ``search`` returns
            ``(distances, ids)`` arrays of shape ``(1, top_k)``.
        :param top_k: Maximum number of rows to return.
        :param api: Embed the query through the remote client instead of the
            local model.
        :return: A new DataFrame with at most ``top_k`` rows, a fresh
            ``RangeIndex`` and a ``'similarities'`` column holding the index
            distance (smaller means closer), sorted ascending.
        """
        query = self.QUERY_INSTRUCTION + query

        if not api:
            # Use the GPU when one is available.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            query_vector = model.encode(query, convert_to_tensor=True, device=device).cpu().numpy()
        else:
            res = model.embeddings(
                input=[query],
                model=self.config["sentence_transformer_name"],
                prompt=None,
            )
            query_vector = np.asarray(res.data[0].embedding)

        # The index stores unit vectors; normalise the query the same way in
        # both modes so the returned distances are comparable between them.
        query_vector = np.asarray(query_vector, dtype=np.float32).reshape(1, -1)
        norm = np.linalg.norm(query_vector)
        if norm > 0:
            query_vector = query_vector / norm

        distances, indices = index.search(np.ascontiguousarray(query_vector), top_k)
        distances, indices = distances[0], indices[0]

        # FAISS pads with id -1 when fewer than top_k vectors exist; without this
        # filter, df.iloc[-1] would silently repeat the last row.
        valid = indices >= 0
        retrieved_df = df.iloc[indices[valid]].assign(similarities=distances[valid])

        # Stable sort keeps the index's own order among equal distances.
        retrieved_df = retrieved_df.sort_values(by='similarities', ascending=True, kind='stable')
        return retrieved_df.reset_index(drop=True)

    def rerank(self, query, df_top_100, cross_encoder, index, top_k=10,
               max_candidates=100, max_chars=256):
        """Re-order candidate rows with a cross-encoder and keep the best ones.

        :param query: Query string, passed to ``cross_encoder.rank`` unchanged.
        :param df_top_100: Candidate rows, e.g. the output of ``retrieve``.
        :param cross_encoder: Object whose ``rank(query, documents, top_k=...)``
            returns dicts carrying a ``"corpus_id"`` position into ``documents``.
        :param index: Unused; kept so existing callers keep working.
        :param top_k: Number of re-ranked rows to return.
        :param max_candidates: Only the first ``max_candidates`` rows are scored.
        :param max_chars: Each document's ``"inputs"`` is sliced to this length
            before scoring.
        :return: The selected candidate rows (index reset to positions), in the
            order returned by ``cross_encoder.rank``.
        """
        candidates = df_top_100.reset_index(drop=True).iloc[:max_candidates]
        records = candidates.to_dict(orient='records')

        # NOTE(review): assumes SchemaStringDataset yields one item per record,
        # in record order, each with a sliceable "inputs" (presumably a str).
        dataset_str = SchemaStringDataset(records, self.config)
        documents = [item["inputs"][:max_chars] for item in dataset_str]

        ranked = cross_encoder.rank(query, documents, top_k=top_k)
        ids = [hit["corpus_id"] for hit in ranked]

        # After reset_index the labels equal positions, so .loc on the
        # positional corpus_ids selects the right rows.
        return candidates.loc[ids]

    def search(self, query, df, model, cross_encoder, index):
        """Run ``retrieve`` (default ``top_k``) and then ``rerank`` on its output.

        :return: The DataFrame produced by ``rerank``.
        """
        candidates = self.retrieve(query, df, model, index)
        return self.rerank(query, candidates, cross_encoder, index)

    @staticmethod
    def top_10_common_values(df, column_name, n=10):
        """Return up to ``n`` (default 10) most frequent non-null values of a column.

        :param df: Input DataFrame.
        :param column_name: Column whose values are counted.
        :param n: Maximum number of values to return.
        :return: A list of values, ordered from most to least frequent.
        """
        return df[column_name].dropna().value_counts().head(n).index.tolist()

    @staticmethod
    def filter_dataframe(df, config, top_k_programmatic=100):
        """Filter ``df`` by the scalar ranges and discrete values in ``config``.

        Rows with a null in any configured column are dropped first. Each scalar
        column is coerced to numeric (unparseable values become NaN and fail the
        range test) and kept within the inclusive ``[min_value, max_value]``
        range. Each discrete column with a non-empty ``default_values`` list is
        restricted to those values.

        :param df: DataFrame to filter. It is not modified.
        :param config: Dict with optional ``'scalar_columns'`` (items with
            ``column_name``, ``min_value``, ``max_value``) and
            ``'discrete_columns'`` (items with ``column_name``,
            ``default_values``).
        :param top_k_programmatic: Currently unused (see note below).
        :return: The filtered DataFrame, sorted by ``'similarities'`` descending
            when that column exists.
        """
        scalar_columns = config.get('scalar_columns', [])
        discrete_columns = config.get('discrete_columns', [])

        all_columns = [col["column_name"] for col in (*scalar_columns, *discrete_columns)]
        df = df.dropna(subset=all_columns)

        for col in scalar_columns:
            column_name = col["column_name"]
            min_value = float(col["min_value"])
            max_value = float(col["max_value"])
            numeric = pd.to_numeric(df[column_name], errors='coerce')
            in_range = numeric.between(min_value, max_value)
            # Filter into an owned copy before writing the converted column, so
            # we never assign into a slice (avoids SettingWithCopyWarning).
            df = df[in_range].copy()
            df[column_name] = numeric[in_range]

        for col in discrete_columns:
            default_values = col["default_values"]
            if len(default_values) > 0:
                df = df[df[col["column_name"]].isin(default_values)]

        # NOTE(review): ``retrieve`` stores FAISS L2 *distances* in
        # 'similarities' (smaller = closer). If ``df`` comes from ``retrieve``,
        # descending order puts the weakest matches first — confirm intent.
        if 'similarities' in df.columns:
            df = df.sort_values(by='similarities', ascending=False)

        # NOTE(review): ``top_k_programmatic`` is never applied; the whole
        # filtered frame is returned.
        return df

    @staticmethod
    def drop_columns(df, config):
        """Return ``df`` without the columns listed in ``config['columns_to_drop']``.

        A missing key drops nothing. Listing a column absent from ``df`` raises
        ``KeyError`` (pandas ``drop`` default).
        """
        columns_to_drop = config.get('columns_to_drop', [])
        return df.drop(columns_to_drop, axis=1)