from llama_index.core import Settings, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core import PromptTemplate
# from pydantic import PrivateAttr
# import requests
from typing import List, Optional, Union
# from llama_index.core.embeddings.utils import BaseEmbedding  
from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
# from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import streamlit as st
from functools import lru_cache
import pickle
import os


# Mistral API key read from the "mistral_api_key" environment variable when the
# module is imported (None if unset); it is the default key used by QASystem.
mistral_api_key = os.environ.get("mistral_api_key")

class QASystem:
    """Retrieval-augmented question answering over the documents in ``data_dir``.

    Wires a Hugging Face Inference API embedding model
    (``intfloat/multilingual-e5-large``) and a Mistral chat model into the
    LlamaIndex globals, loads the vector index persisted in ``storage_dir``
    (building it from ``data_dir`` when nothing is persisted yet), and answers
    questions through a streaming query engine driven by an Arabic prompt.
    """

    def __init__(self,
                 mistral_api_key: Optional[str] = mistral_api_key,
                 data_dir: str = "./data",
                 storage_dir: str = "./index_llama_136_multilingual-e5-large",
                 model_temperature: float = 0.002):
        """Create the models, register them globally and load (or build) the index.

        Args:
            mistral_api_key: Mistral API key. Defaults to the module-level value
                read from the environment at import time (may be None).
            data_dir: Directory of source documents; only read when the index
                has to be built.
            storage_dir: Directory holding the persisted vector index.
            model_temperature: Sampling temperature for the Mistral model.
        """
        self.data_dir = data_dir
        self.storage_dir = storage_dir

        # Embedding model served through the Hugging Face Inference API.
        self.embedding_model = HuggingFaceInferenceAPIEmbedding(
            model_name="intfloat/multilingual-e5-large",
        )

        self.llm = MistralAI(
            model="mistral-large-latest",
            api_key=mistral_api_key,
            temperature=model_temperature,
            max_tokens=1024,
        )
        self._configure_settings()

        # Load the persisted index, or build and persist it on first run.
        self.index = self.load_index()

    def _configure_settings(self):
        """Register this instance's LLM and embedding model as LlamaIndex defaults.

        ``Settings`` is a module-level global, so this also affects any other
        LlamaIndex component created afterwards in the same process.
        """
        Settings.llm = self.llm
        Settings.embed_model = self.embedding_model

    def create_index(self):
        """Build a sentence-chunked vector index from ``data_dir`` and persist it.

        Documents are split with ``SentenceSplitter`` (chunk_size=206, no
        overlap), embedded, and the resulting index is written to ``storage_dir``.

        Returns:
            The newly built ``VectorStoreIndex``.
        """
        print("creating index")
        documents = SimpleDirectoryReader(self.data_dir).load_data()
        node_parser = SentenceSplitter(chunk_size=206, chunk_overlap=0)
        nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)

        # Pass the embedding model explicitly, as load_index and the query engine
        # do, instead of relying on whatever the global Settings currently hold.
        sentence_index = VectorStoreIndex(
            nodes, embed_model=self.embedding_model, show_progress=True
        )
        sentence_index.storage_context.persist(self.storage_dir)
        return sentence_index

    def load_index(self):
        """Return the vector index persisted in ``storage_dir``.

        When ``storage_dir`` does not exist yet, the index is built from
        ``data_dir`` with :meth:`create_index` (which also persists it) instead
        of failing; an existing directory is loaded exactly as before.
        """
        if not os.path.isdir(self.storage_dir):
            return self.create_index()
        storage_context = StorageContext.from_defaults(persist_dir=self.storage_dir)
        return load_index_from_storage(storage_context, embed_model=self.embedding_model)

    def create_query_engine(self):
        """Build a streaming query engine over ``self.index``.

        Retrieves the 10 most similar chunks and synthesizes the answer in
        ``tree_summarize`` mode with an Arabic prompt that tells the model to
        answer in Arabic and to admit when it does not know.
        """
        template = """
        استخدم المعلومات التالية للإجابة على السؤال في النهاية. إذا لم تكن تعرف الإجابة، فقل فقط أنك لا تعرف، لا تحاول اختلاق إجابة.

        {context_str}
        السؤال: {query_str}
        الإجابة بالعربية:
        """
        prompt = PromptTemplate(template=template)

        return self.index.as_query_engine(
            similarity_top_k=10,
            streaming=True,
            # tree_summarize reads `summary_template`, not `text_qa_template`;
            # without it the Arabic prompt was silently replaced by the default
            # English summarize prompt. Pass both so the prompt applies in any mode.
            text_qa_template=prompt,
            summary_template=prompt,
            response_mode="tree_summarize",  # alternatives: simple_summarize, compact
            embed_model=self.embedding_model,
        )

    def query(self, question: str):
        """Answer ``question`` using a freshly built streaming query engine.

        Returns the engine's response object; callers iterate ``response_gen``
        for answer tokens and read ``source_nodes`` for the retrieved chunks.
        """
        return self.create_query_engine().query(question)


# Share a single QASystem across reruns and sessions. Streamlit re-executes this
# script on every interaction, which recreates module-level functions together
# with any functools.lru_cache attached to them, so an lru_cache here would
# rebuild the models and reload the index on every click. st.cache_resource
# keeps the instance outside the script namespace.
@st.cache_resource
def get_qa_system():
    """Return the shared QASystem, creating it on the first call."""
    return QASystem()

def _inject_css():
    """Inject the global stylesheet: Noto Kufi Arabic font and right-to-left layout."""
    st.markdown("""  
    <style>  
        @import url('https://fonts.googleapis.com/css2?family=Noto+Kufi+Arabic:wght@100;200;300;400;500;600;700;800;900&display=swap');

        /* Application globale de la police */
        * {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        body {  
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        /* Style pour tous les textes */
        p, div, span, button, input, label, h1, h2, h3, h4, h5, h6 {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }

        /* Titre principal avec taille réduite */
        h1 {
            font-size: 1.2em !important;
            margin-bottom: 0.5em !important;
            text-align: center;
            padding: 0.3px;
        }

        .css-1h9b9rq.e1tzin5v0 {
            direction: rtl; 
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        /* Style pour l'expandeur */
        .streamlit-expanderContent, div[data-testid="stExpander"] {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Style pour les boutons de l'expandeur */
        button[kind="secondary"] {
            direction: rtl !important;
            text-align: right !important;
            width: 100% !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 30 !important;
        }
        
        /* Style pour tous les éléments de texte */
        p, div {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Style pour les bullet points */
        ul, li {
            direction: rtl !important;
            text-align: right !important;
            margin-right: 20px !important;
            margin-left: 0 !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        .stTextInput, .stButton {  
            margin-left: auto;  
            margin-right: 0;  
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        .stTextInput {  
            width: 100% !important;   
            direction: rtl; 
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Force RTL sur tous les conteneurs */
        .element-container, .stMarkdown {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Style spécifique pour l'expandeur des sources */
        .css-1fcdlhc, .css-1629p8f {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }

        /* Style pour le titre */
        .stTitle {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 700 !important;
        }

        /* Style pour les boutons */
        .stButton>button {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 500 !important;
        }

        /* Style pour les champs de texte */
        .stTextInput>div>div>input {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
    </style>  
    """, unsafe_allow_html=True)


def _stream_answer(response):
    """Render the streamed answer token by token and return the full text.

    The running text is also kept in ``st.session_state.current_response``.
    """
    container = st.empty()
    st.session_state.current_response = ""
    for token in response.response_gen:
        st.session_state.current_response += token
        # The answer is model output shaped by the user's question, so render it
        # as plain Markdown rather than raw HTML.
        container.markdown(st.session_state.current_response)
    return st.session_state.current_response


def _render_sources(source_nodes):
    """Show each retrieved chunk (file name and full text) in a collapsible section."""
    import html  # only needed here, to escape document text placed into HTML

    st.markdown("""
        <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
            <div class="streamlit-expanderHeader">
                المصادر
            </div>
        </div>
    """, unsafe_allow_html=True)
    with st.expander(""):
        for node in source_nodes:
            # Escape document content before interpolating it into raw HTML so
            # '<', '>' or '&' in the text cannot break (or inject) markup, and turn
            # newlines into <br> so a blank line cannot end the HTML block early.
            file_name = html.escape(str(node.metadata.get('file_name', 'Unknown')))
            excerpt = html.escape(node.text).replace("\n", "<br>")
            st.markdown(f"""
                <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
                    <p style="text-align: right !important;">مصادر الجواب : {file_name}</p>
                    <p style="text-align: right !important;">Extrait: {excerpt}</p>
                </div>
            """, unsafe_allow_html=True)


def main():
    """Streamlit entry point: render the page, answer the question, list the sources."""
    _inject_css()

    st.title("هذا تطبيق للاجابة عن الاسئلة المتعلقة بالقانون المغربي ")
    st.title("حاليا يضم 136 قانونا")

    qa_system = get_qa_system()

    question = st.text_input("اطرح سؤالك :", placeholder=None)

    # st.button is always evaluated first so the widget renders on every rerun;
    # empty or whitespace-only questions are not sent to the model.
    if not st.button("بحث") or not (question and question.strip()):
        return

    try:
        response = qa_system.query(question)
        _stream_answer(response)
        source_nodes = getattr(response, 'source_nodes', None)
        if source_nodes:
            _render_sources(source_nodes)
    except Exception as e:
        # Top-level UI boundary: show the failure on the page instead of crashing.
        st.error(f"Une erreur s'est produite : {e}")

if __name__ == "__main__":
    # Script entry point (e.g. when launched with `streamlit run`).
    main()