import json
import logging
import math
import os

import numpy as np
from dotenv import load_dotenv
from sentence_transformers import CrossEncoder

from middlewares.search_client import SearchClient


# Populate os.environ from a local .env file (if present). This must run
# before the os.getenv() calls below so .env values are visible to them.
load_dotenv()


# Search credentials read from the environment; os.getenv returns None when a
# variable is unset, and that None is passed straight to SearchClient below.
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder used by rerank() to score (query, chunk) pairs. It is built at
# import time, so importing this module loads the model (sentence-transformers
# may download the weights on first use — NOTE(review): consider lazy loading).
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One client per supported search vendor, selected by
# gen_augmented_prompt_via_websearch(). Bing takes no engine id here.
# NOTE(review): behavior when api_key is None depends on SearchClient — confirm.
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)




def rerank(query, top_k, search_results, chunk_size=512, model=None):
    """Split search results into word chunks and return the most relevant ones.

    Each result's ``"text"`` is split on whitespace into chunks of at most
    ``chunk_size`` words. Every chunk is scored against ``query`` and the
    chunks are returned best-first.

    Args:
        query: The query the chunks are scored against.
        top_k: Maximum number of chunks to return (``None`` returns all).
        search_results: Iterable of dicts with ``"link"`` and ``"text"`` keys.
            Results whose text is missing, ``None`` or blank contribute no
            chunks.
        chunk_size: Maximum number of words per chunk; must be positive.
        model: Object exposing ``predict(pairs)`` that returns one score per
            ``[query, chunk]`` pair. Defaults to the module-level
            ``reranker`` cross-encoder.

    Returns:
        A list of ``{"link": ..., "text": ...}`` dicts ordered by decreasing
        score. Chunks with equal scores keep their original relative order.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

    chunks = []  # (link, chunk_text) pairs, in search-result order
    for result in search_results:
        words = (result.get("text") or "").split()
        for start in range(0, len(words), chunk_size):
            chunks.append((result["link"], " ".join(words[start:start + chunk_size])))

    # Nothing to score: don't hand the model an empty batch (some
    # CrossEncoder.predict versions index the first element and crash).
    if not chunks:
        return []

    scorer = reranker if model is None else model
    scores = np.asarray(
        scorer.predict([[query, text] for _, text in chunks]), dtype=float
    )

    # A stable sort on negated scores yields descending order with ties kept in
    # input order; reversing a (quicksort) ascending argsort left tie order
    # unspecified.
    order = np.argsort(-scores, kind="stable")[:top_k]
    return [{"link": chunks[i][0], "text": chunks[i][1]} for i in order]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Build an LLM prompt augmented with reranked web-search context.

    The prompt is searched with the selected vendor. The results are reranked
    with ``rerank`` and the top chunks are placed in a fixed template, wrapped
    by the caller-supplied pre/post strings.

    Search and reranking are best-effort. Any failure is logged and the prompt
    is built without web context.

    Args:
        prompt: The user prompt; also used as the search query.
        search_vendor: ``"Google"`` or ``"Bing"`` (exact match). Any other
            value skips the search.
        n_crawl: Passed to ``SearchClient.search`` (presumably the number of
            results to fetch — confirm against SearchClient).
        top_k: Number of reranked chunks to include as context.
        pre_context / post_context: Text placed before / after the context.
        pre_prompt / post_prompt: Text placed before / after ``prompt``.
        pass_prev: Whether to append ``prev_output`` to the prompt.
        prev_output: Previous model output, included only if ``pass_prev``.
        chunk_size: Words per chunk used when reranking.

    Returns:
        ``(augmented_prompt, links)``. ``links`` holds the unique source links
        of the context chunks, in relevance order.
    """
    logger = logging.getLogger(__name__)

    reranked_results = []
    try:
        search_results = []
        if search_vendor == "Google":
            search_results = googleSearchClient.search(prompt, n_crawl)
        elif search_vendor == "Bing":
            logger.info("[Bing search enabled]")
            search_results = bingSearchClient.search(prompt, n_crawl)
            logger.debug("Bing search results: %s", search_results)
            logger.info("[Bing search completed]")
        if search_results:
            reranked_results = rerank(prompt, top_k, search_results, chunk_size)
    except Exception:
        # Best effort: a failed search/rerank degrades to a prompt without
        # web context instead of failing the whole request.
        logger.exception("Web search augmentation failed; continuing without context")
        reranked_results = []

    context = "".join(res["text"] + "\n\n" for res in reranked_results)
    # dict.fromkeys de-duplicates while keeping first-seen (i.e. relevance)
    # order; list(set(...)) scrambled it.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
    
    {pre_context}

    {context}

    
    {post_context}
    
    {pre_prompt} 
    
    {prompt} 
    
    {post_prompt}

    {prev_output}

    """

    logger.debug("Augmented prompt:\n%s", augmented_prompt)
    return augmented_prompt, links