from llama_index.core import (
    Settings,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
    SimpleDirectoryReader,
    PromptTemplate,
)
from llama_index.llms.mistralai import MistralAI
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
# Local alternative (runs the embedding model in-process instead of calling the API):
# from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import streamlit as st
from functools import lru_cache
import os


mistral_api_key = os.getenv("mistral_api_key")  # Mistral API key, read from the environment

class QASystem:
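    """Retrieval-augmented question answering over the documents in data_dir.

    Documents are embedded with intfloat/multilingual-e5-large through the
    Hugging Face Inference API, and answers are generated in Arabic by a
    Mistral model via a LlamaIndex query engine.
    """
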
    def __init__(self, 
                 mistral_api_key: str = mistral_api_key,
                 data_dir: str = "./data",
                 storage_dir: str = "./index_llama_136_multilingual-e5-large",
                 model_temperature: float = 0.002):
        self.data_dir = data_dir
        self.storage_dir = storage_dir
        
        # Embedding model served through the Hugging Face Inference API
        self.embedding_model = HuggingFaceInferenceAPIEmbedding(
                                    model_name="intfloat/multilingual-e5-large",
                                    )
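        # NOTE (assumption): the hosted Inference API usually also needs a Hugging
        # Face token (e.g. a token argument or an HF_TOKEN environment variable),
        # depending on the installed llama-index version.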
        # Local alternative:
        # self.embedding_model = HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-large", trust_remote_code=True)
        
        self.llm = MistralAI(
            model="mistral-large-latest", 
            api_key=mistral_api_key, 
            temperature=model_temperature,
            max_tokens=1024
        )
        self._configure_settings()

        # To (re)build the vector store from the documents in data_dir, call
        # self.create_index() here instead of load_index().

        self.index = self.load_index()  # Load the persisted index from disk
        
    def _configure_settings(self):
        Settings.llm = self.llm
        Settings.embed_model = self.embedding_model

    def create_index(self):
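        """Chunk the documents in data_dir, embed them, and persist the index to storage_dir."""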
        print("creating index")
        documents = SimpleDirectoryReader(self.data_dir).load_data()
        node_parser = SentenceSplitter(chunk_size=206, chunk_overlap=0)
        nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)
        
        sentence_index = VectorStoreIndex(nodes, show_progress=True)
        sentence_index.storage_context.persist(self.storage_dir)

        return sentence_index
    
    def load_index(self):
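        """Load the vector index previously persisted under storage_dir."""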
        storage_context = StorageContext.from_defaults(persist_dir=self.storage_dir)
        return load_index_from_storage(storage_context, embed_model=self.embedding_model)
    
    def create_query_engine(self):
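        """Build a streaming query engine that answers in Arabic from the top-10 retrieved chunks."""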
        template = """
        استخدم المعلومات التالية للإجابة على السؤال في النهاية. إذا لم تكن تعرف الإجابة، فقل فقط أنك لا تعرف، لا تحاول اختلاق إجابة.

        {context_str}
        السؤال: {query_str}
        الإجابة بالعربية:
        """
        prompt = PromptTemplate(template=template)
        
        return self.index.as_query_engine(
            similarity_top_k=10,
            streaming=True,
            text_qa_template=prompt,
            response_mode="tree_summarize", #tree_summarize, simple_summarize, compact
            embed_model=self.embedding_model
        )
    
    def query(self, question: str):
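        """Answer a single question through a freshly built query engine (streaming response)."""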
        query_engine = self.create_query_engine()
        response = query_engine.query(question)
        return response


# Singleton accessor so the QA system is initialised only once per process
# @st.cache_resource
@lru_cache(maxsize=1)
def get_qa_system():  
    return QASystem() 
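

# Quick sanity check outside Streamlit (a hypothetical usage sketch; assumes the
# persisted index and API keys are in place):
#
#   qa = get_qa_system()
#   qa.query("...").print_response_stream()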

def main():  
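    """Render the Streamlit UI: RTL Arabic styling, a question box, and streamed answers with sources."""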
    st.markdown("""  
    <style>  
        @import url('https://fonts.googleapis.com/css2?family=Noto+Kufi+Arabic:wght@100;200;300;400;500;600;700;800;900&display=swap');

        /* Apply the font globally */
        * {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        body {  
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        /* Style for all text */
        p, div, span, button, input, label, h1, h2, h3, h4, h5, h6 {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }

        /* Main title, reduced size */
        h1 {
            font-size: 1.2em !important;
            margin-bottom: 0.5em !important;
            text-align: center;
            padding: 0.3px;
        }

        .css-1h9b9rq.e1tzin5v0 {
            direction: rtl; 
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        /* Expander style */
        .streamlit-expanderContent, div[data-testid="stExpander"] {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Expander button style */
        button[kind="secondary"] {
            direction: rtl !important;
            text-align: right !important;
            width: 100% !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 300 !important;
        }
        
        /* Style for all text elements */
        p, div {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Bullet point style */
        ul, li {
            direction: rtl !important;
            text-align: right !important;
            margin-right: 20px !important;
            margin-left: 0 !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        .stTextInput, .stButton {  
            margin-left: auto;  
            margin-right: 0;  
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }  
        
        .stTextInput {  
            width: 100% !important;   
            direction: rtl; 
            text-align: right;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Force RTL on all containers */
        .element-container, .stMarkdown {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
        
        /* Specific style for the sources expander */
        .css-1fcdlhc, .css-1629p8f {
            direction: rtl !important;
            text-align: right !important;
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }

        /* Title style */
        .stTitle {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 700 !important;
        }

        /* Button style */
        .stButton>button {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
            font-weight: 500 !important;
        }

        /* Text input style */
        .stTextInput>div>div>input {
            font-family: 'Noto Kufi Arabic', sans-serif !important;
        }
    </style>  
    """, unsafe_allow_html=True)  
    
    st.title("هذا تطبيق للاجابة عن الاسئلة المتعلقة بالقانون المغربي ")  
    st.title("حاليا يضم 136 قانونا")
    
    qa_system = get_qa_system()  
    
    question = st.text_input("اطرح سؤالك :")
    
    if st.button("بحث"):  
        if question:  
            response_container = st.empty()  
            def stream_response(token):  
                if 'current_response' not in st.session_state:  
                    st.session_state.current_response = ""  
                st.session_state.current_response += token  
                response_container.markdown(st.session_state.current_response, unsafe_allow_html=True)  
            
            try:  
                query_engine = qa_system.create_query_engine()  
                st.session_state.current_response = ""  
                response = query_engine.query(question)  
                
                full_response = ""  
                for token in response.response_gen:  
                    full_response += token  
                    stream_response(token)  
                
                if hasattr(response, 'source_nodes'):  
                    st.markdown("""
                        <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
                            <div class="streamlit-expanderHeader">
                                المصادر
                            </div>
                        </div>
                    """, unsafe_allow_html=True)
                    with st.expander(""):  
                        for node in response.source_nodes:  
                            st.markdown(f"""
                                <div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
                                    <p style="text-align: right !important;">مصادر الجواب : {node.metadata.get('file_name', 'Unknown')}</p>
                                    <p style="text-align: right !important;">Extrait: {node.text[:]}</p>
                                </div>
                            """, unsafe_allow_html=True)
            except Exception as e:  
                st.error(f"Une erreur s'est produite : {e}")  

if __name__ == "__main__":  
    main()
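

# Note on (re)building the index: documents added under ./data are only picked up
# after create_index() is run again (see the commented-out call in QASystem.__init__);
# load_index() only reads the store already persisted in storage_dir.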