import gradio as gr
from huggingface_hub import login, InferenceClient
import os
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import umap
import pandas as pd

HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

login(token=HF_TOKEN)
client = InferenceClient(token=HF_TOKEN)


embeddings = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1")

db_code = FAISS.load_local("faiss_code_education",
        embeddings,
        allow_dangerous_deserialization=True)

reducer = umap.UMAP()
index = db_code.index
ntotal = min(index.ntotal, 4998)
embeds = index.reconstruct_n(0, ntotal)
umap_embeds = reducer.fit_transform(embeds)

articles_df = pd.DataFrame({
    "x" : umap_embeds[:,0],
    "y" : umap_embeds[:,1],
    "type" : [ "Source" ] * len(umap_embeds),
})

system_prompt = """Tu es un assistant juridique spécialisé dans le Code de l'éducation français. 
Ta mission est d'aider les utilisateurs à comprendre la législation en répondant à leurs questions.

Voici comment tu dois procéder :

1. **Analyse de la question:** Lis attentivement la question de l'utilisateur.
2. **Identification des articles pertinents:** Examine les 10 articles de loi fournis et sélectionne ceux qui sont les plus pertinents pour répondre à la question.
3. **Formulation de la réponse:** Rédige une réponse claire et concise en français, en utilisant les informations des articles sélectionnés. Cite explicitement les articles que tu utilises (par exemple, "Selon l'article L311-3...").
4. **Structure de la réponse:** Si ta réponse s'appuie sur plusieurs articles, structure-la de manière logique, en séparant les informations provenant de chaque article.
5. **Cas ambigus:** 
* Si la question est trop vague, demande des précisions à l'utilisateur.
* Si plusieurs articles pourraient s'appliquer, présente les différentes interprétations possibles."""


def query_rag(query, model, system_prompt):
    docs = db_code.similarity_search(query, 10)

    article_dict = {}
    context_list = []
    for doc in docs:
        article = doc.metadata
        context_list.append(' > '.join(article['chemin'])+'\n'+article['texte']+'\n---\n')
        article_dict[article['article']] = article

    user = 'Question de l\'utilisateur : ' + query + '\nContexte législatif :\n' + '\n'.join(context_list)

    messages = [ { "role" : "system", "content" : system_prompt } ]
    messages.append( { "role" : "user", "content" : user } )

    if "factice" in model:
        return user, article_dict

    chat_completion = client.chat_completion(
        messages=messages,
        model=model,
        max_tokens=1024)

    return chat_completion.choices[0].message.content, article_dict

def create_context_response(response, article_dict):
    context = '\n'
    for i, article in enumerate(article_dict):
        art = article_dict[article]
        context += '* **' + ' > '.join(art['chemin']) + '** : '+ art['texte'].replace('\n', '\n    ')+'\n'
    return context

def chat_interface(query, model, system_prompt):
    response, article_dict = query_rag(query, model, system_prompt)
    context = create_context_response(response, article_dict)
    return response, context

def update_plot(query):
    query_embed = embeddings.embed_documents([query])[0]
    query_umap_embed = reducer.transform([query_embed])
    
    data = {
        "x": umap_embeds[:, 0].tolist() + [query_umap_embed[0, 0]],
        "y": umap_embeds[:, 1].tolist() + [query_umap_embed[0, 1]],
        "type": ["Source"] * len(umap_embeds) + ["Requête"]
    }
    return pd.DataFrame(data)

with gr.Blocks(title="Assistant Juridique pour le Code de l'éducation (Beta)") as demo:
    gr.Markdown(
        """
        ## Posez vos questions sur le Code de l'éducation
        
        **Créé par Marc de Falco**

        **Avertissement :** Les informations fournies sont données à titre indicatif et ne constituent pas un avis juridique. Les serveurs étant publics, veuillez ne pas inclure de données sensibles.
        """
    )

    query_box = gr.Textbox(label="Votre question")

    model = gr.Dropdown(
        label="Modèle de langage",
        choices=[
            "meta-llama/Meta-Llama-3-70B-Instruct",
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "HuggingFaceH4/zephyr-7b-beta",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mixtral-8x22B-v0.1",
            "factice: question+contexte"
            ],
        value="meta-llama/Meta-Llama-3-70B-Instruct")

    submit_button = gr.Button("Envoyer")

    with gr.Tab(label="Réponse"):
        response_box = gr.Markdown()
    with gr.Tab(label="Sources"):
        sources_box = gr.Markdown()
    with gr.Tab(label="Visualisation"):
        scatter_plot = gr.ScatterPlot(articles_df,
                x = "x", y = "y",
                color="type",
                label="Visualisation des embeddings",
                width=500,
                height=500)
    with gr.Tab(label="Paramètres"):
        system_box = gr.Textbox(label="Invite systeme", value=system_prompt,
                                lines=20)

    submit_button.click(chat_interface, 
                inputs=[query_box, model, system_box], 
                outputs=[response_box, sources_box])
    submit_button.click(update_plot, inputs=[query_box], outputs=[scatter_plot])

demo.launch()