import collections.abc  # for the generator type annotation
import json  # to work with JSON
import threading  # to run text generation in a background thread
import time  # to pace the delivery of the streamed message

import faiss  # to create a search index
import gradio  # for the interface
import numpy  # to work with vectors
import pandas  # to work with tabular data
import sentence_transformers  # to load an embedding model
import spaces  # for GPU
import transformers  # to load an LLM

# Constants
GREETING = (
    "Howdy! "
    "I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the [Design Research Collective](https://cmudrc.github.io/). "
    "And the best part is that I always try to cite my sources! "
    "I still make some mistakes though. " 
    "What can I tell you about today?"
)
EXAMPLE_QUERIES = [
    "Tell me about new research at the intersection of additive manufacturing and machine learning.",
    "What is a physics-informed neural network and what can it be used for?",
    "What can agent-based models do about climate change?",
    "What's the difference between a markov chain and a hidden markov model?",
    "What are the latest advancements in reinforcement learning?",
    "What is known about different modes for human-AI teaming?",
]
EMBEDDING_MODEL_NAME = "allenai-specter"
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
PUBLICATIONS_TO_RETRIEVE = 5
PARQUET_URL = "hf://datasets/ccm/publications/data/train-00000-of-00001.parquet"
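# The parquet dataset is assumed to already include an "embedding" column with
# vectors produced by the same model named in EMBEDDING_MODEL_NAME.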

# Load the dataset and convert to pandas
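# (Reading an "hf://" URL relies on the huggingface_hub filesystem backend for fsspec.)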
data = pandas.read_parquet(PARQUET_URL)

# Filter out any publications without an abstract
abstract_is_null = [
    '"abstract": null' in json.dumps(bibdict) for bibdict in data["bib_dict"].values
]
data = data[~pandas.Series(abstract_is_null)]
data.reset_index(inplace=True)

# Load the embedding model used to encode user queries for retrieval
model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Load the tokenizer, token streamer, and chat model used to generate responses
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)
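# The streamer yields decoded tokens as they are generated, which lets reply()
# below run generation in a background thread and stream partial answers.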
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)

# Create a FAISS index for fast similarity search
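# The embedding vectors are L2-normalized so that inner-product search over a
# flat index is equivalent to cosine similarity.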
vectors = numpy.stack(data["embedding"].tolist(), axis=0)
index = faiss.IndexFlatIP(vectors.shape[1])
faiss.normalize_L2(vectors)
index.add(vectors)


def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant papers to the query and returns a prompt and references
    Args:
        query (str): The user's query
        k (int): The number of results to return
    Returns:
        tuple[str, str]: A tuple containing the prompt and references
    """
    encoded_query = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(encoded_query)
    _, indices = index.search(encoded_query, k)
    top_papers = data.loc[indices[0]]

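    # Prompt template; the {{...}} placeholders are filled in by string
    # replacement below.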
    prompt = (
        "You are an AI assistant who delights in helping people learn about research from the Design Research Collective, which is a research lab at Carnegie Mellon University led by Professor Chris McComb. "
        "Your main task is to provide a concise ANSWER to the USER_QUERY that includes as many of the RESEARCH_ABSTRACTS as possible. "
        "The RESEARCH_ABSTRACTS are provided in the `.bibtex` format. Your ANSWER should contain citations to the RESEARCH_ABSTRACTS using (AUTHOR, YEAR) format. "
        "DO NOT list references at the end of the answer.\n\n"
        "RESEARCH_ABSTRACTS:\n```bibtex\n{{ABSTRACTS_GO_HERE}}\n```\n\n"
        "USER_GUERY:\n{{QUERY_GOES_HERE}}\n\n"
        "ANSWER:\n"
    )

    references = []
    research_abstracts = ""

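    # For each retrieved publication, collect its BibTeX entry for the prompt
    # and build an HTML citation link to its Google Scholar page.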
    for i in range(k):
        bib = top_papers["bib_dict"].values[i]
        year = str(int(bib["pub_year"]))
        url = (
            "https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
            + top_papers["author_pub_id"].values[i]
        )
        last_names = [author.split(" ")[-1] for author in bib["author"].split(" and ")]
        first_author_last_name = last_names[0]

        research_abstracts += top_papers["bibtex"].values[i] + "\n"
        references.append(f'<a href="{url}">{first_author_last_name} {year}</a>')
        
    prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)

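    # Log the fully assembled prompt for debugging.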
    print(prompt)
    
    return prompt, "; ".join(references)


@spaces.GPU
def reply(message: str, history: list[list[str | None]]) -> collections.abc.Generator[str, None, None]:
    """
    Craft and stream a response to the user's message.

    Args:
        message (str): The user's message
        history (list[list[str | None]]): The conversation history as [user, assistant] pairs

    Yields:
        str: The partial response, streamed as it is generated
    """

    # Build the retrieval-augmented prompt and the citation links for this message
    prompt, references = preprocess(message, PUBLICATIONS_TO_RETRIEVE)

    # Convert the Gradio pair-format history into chat-template message dicts
    history_transformer_format = [
        {"role": role, "content": message_pair[idx]}
        for message_pair in history
        for idx, role in enumerate(["user", "assistant"])
        if message_pair[idx] is not None
    ] + [{"role": "user", "content": prompt}]

    # Apply the model's chat template and tokenize the full conversation
    text = tokenizer.apply_chat_template(
        history_transformer_format, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")

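    # Run generation in a background thread; the streamer yields decoded
    # tokens to this function as they are produced.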
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
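    # Stream tokens to the UI, skipping bare "<" tokens (presumably fragments
    # of special tokens) and sleeping briefly to pace the output.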
    for new_token in streamer:
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.01)
            yield partial_message

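    # Append the citation links collected during retrieval to the final answer.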
    yield partial_message + "\n\n" + references

    

# Create and run the gradio interface
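# Note: the list-of-pairs chatbot value and the retry/undo/clear button
# arguments below assume the Gradio 4.x ChatInterface API.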
gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=gradio.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=False,
        value=[[None, GREETING]],
        avatar_images=[
            "https://cdn.dribbble.com/users/316121/screenshots/2333676/11-04_scotty-plaid_dribbble.png",
            "https://media.thetab.com/blogs.dir/90/files/2021/06/screenshot-2021-06-10-at-110730-1024x537.png",
        ],
        height="60vh",
        bubble_full_width=False,
    ),
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
    theme=gradio.themes.Default(
        font=[gradio.themes.GoogleFont("Zilla Slab")]
    )
).launch(debug=True)