# RAG_LOI_v2 / app.py
# Uploaded by Trabis — "Update app.py" (commit 548a3a4, verified)
from llama_index.core import Settings, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core import PromptTemplate
# from pydantic import PrivateAttr
# import requests
from typing import List, Optional, Union
# from llama_index.core.embeddings.utils import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
# from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import streamlit as st
from functools import lru_cache
import pickle
import os
# Read the Mistral API key from the environment (None if unset).
# NOTE(review): the env-var name is lowercase ("mistral_api_key"); environment
# variables are conventionally UPPER_CASE — confirm this matches the deployment
# secrets before renaming.
mistral_api_key = os.getenv("mistral_api_key")
class QASystem:
    """Retrieval-augmented question answering over a persisted llama-index store.

    Queries are embedded remotely via the HuggingFace Inference API
    (intfloat/multilingual-e5-large) and answered by MistralAI using an
    Arabic QA prompt.  The vector index is loaded from ``storage_dir`` at
    construction time, so the persisted index must already exist on disk;
    call :meth:`create_index` once beforehand to build it from ``data_dir``.
    """

    def __init__(self,
                 mistral_api_key: str = mistral_api_key,
                 data_dir: str = "./data",
                 storage_dir: str = "./index_llama_136_multilingual-e5-large",
                 model_temperature: float = 0.002,
                 chunk_size: int = 206,
                 chunk_overlap: int = 0,
                 similarity_top_k: int = 10):
        """Initialize models, register them globally, and load the persisted index.

        Args:
            mistral_api_key: API key for MistralAI; defaults to the module-level
                value read from the environment.
            data_dir: Directory of source documents (used by ``create_index``).
            storage_dir: Directory holding the persisted llama-index storage.
            model_temperature: Sampling temperature for the LLM (near-zero for
                deterministic answers).
            chunk_size: Sentence-splitter chunk size used when building the index.
            chunk_overlap: Overlap between consecutive chunks.
            similarity_top_k: Number of nodes retrieved per query.
        """
        self.data_dir = data_dir
        self.storage_dir = storage_dir
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.similarity_top_k = similarity_top_k
        # Remote embedding via the HF Inference API — no local model download.
        self.embedding_model = HuggingFaceInferenceAPIEmbedding(
            model_name="intfloat/multilingual-e5-large",
        )
        self.llm = MistralAI(
            model="mistral-large-latest",
            api_key=mistral_api_key,
            temperature=model_temperature,
            max_tokens=1024,
        )
        self._configure_settings()
        # Load (not build) the index; storage_dir must already be populated.
        self.index = self.load_index()

    def _configure_settings(self):
        """Register the LLM and embedding model as llama-index global defaults."""
        Settings.llm = self.llm
        Settings.embed_model = self.embedding_model

    def create_index(self):
        """Build a vector index from ``data_dir`` and persist it to ``storage_dir``.

        Returns:
            The freshly built ``VectorStoreIndex`` (also written to disk).
        """
        print("creating index")
        documents = SimpleDirectoryReader(self.data_dir).load_data()
        node_parser = SentenceSplitter(chunk_size=self.chunk_size,
                                       chunk_overlap=self.chunk_overlap)
        nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)
        sentence_index = VectorStoreIndex(nodes, show_progress=True)
        sentence_index.storage_context.persist(self.storage_dir)
        return sentence_index

    def load_index(self):
        """Load the persisted index from ``storage_dir``.

        Raises if the storage directory does not contain a persisted index.
        """
        storage_context = StorageContext.from_defaults(persist_dir=self.storage_dir)
        return load_index_from_storage(storage_context, embed_model=self.embedding_model)

    def create_query_engine(self):
        """Return a streaming query engine that answers in Arabic.

        Uses tree-summarize response synthesis over the top-k retrieved nodes.
        """
        template = """
استخدم المعلومات التالية للإجابة على السؤال في النهاية. إذا لم تكن تعرف الإجابة، فقل فقط أنك لا تعرف، لا تحاول اختلاق إجابة.
{context_str}
السؤال: {query_str}
الإجابة بالعربية:
"""
        prompt = PromptTemplate(template=template)
        return self.index.as_query_engine(
            similarity_top_k=self.similarity_top_k,
            streaming=True,
            text_qa_template=prompt,
            response_mode="tree_summarize",  # alternatives: simple_summarize, compact
            embed_model=self.embedding_model,
        )

    def query(self, question: str):
        """Run *question* through a fresh query engine and return the streaming response."""
        query_engine = self.create_query_engine()
        response = query_engine.query(question)
        return response
# Singleton accessor — avoids re-initializing the (expensive) QA system on
# every Streamlit rerun.  st.cache_resource would also work here.
@lru_cache(maxsize=1)
def get_qa_system():
    """Return a process-wide singleton :class:`QASystem`.

    ``lru_cache`` on a zero-argument function caches exactly one entry, so the
    costly index load happens once per process.  (maxsize=1 replaces the
    original maxsize=1000, which was meaningless for a no-arg function.)
    """
    return QASystem()
def main():
    """Render the Streamlit UI: RTL/Arabic styling, a question box, a streamed
    answer, and an expander listing the source documents of the answer."""
    # Global CSS: Arabic font everywhere, RTL direction on all containers.
    st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Noto+Kufi+Arabic:wght@100;200;300;400;500;600;700;800;900&display=swap');
/* Application globale de la police */
* {
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
body {
text-align: right;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style pour tous les textes */
p, div, span, button, input, label, h1, h2, h3, h4, h5, h6 {
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Titre principal avec taille réduite */
h1 {
font-size: 1.2em !important;
margin-bottom: 0.5em !important;
text-align: center;
padding: 0.3px;
}
.css-1h9b9rq.e1tzin5v0 {
direction: rtl;
text-align: right;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style pour l'expandeur */
.streamlit-expanderContent, div[data-testid="stExpander"] {
direction: rtl !important;
text-align: right !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style pour les boutons de l'expandeur */
button[kind="secondary"] {
direction: rtl !important;
text-align: right !important;
width: 100% !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
font-weight: 30 !important;
}
/* Style pour tous les éléments de texte */
p, div {
direction: rtl !important;
text-align: right !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style pour les bullet points */
ul, li {
direction: rtl !important;
text-align: right !important;
margin-right: 20px !important;
margin-left: 0 !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
.stTextInput, .stButton {
margin-left: auto;
margin-right: 0;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
.stTextInput {
width: 100% !important;
direction: rtl;
text-align: right;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Force RTL sur tous les conteneurs */
.element-container, .stMarkdown {
direction: rtl !important;
text-align: right !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style spécifique pour l'expandeur des sources */
.css-1fcdlhc, .css-1629p8f {
direction: rtl !important;
text-align: right !important;
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
/* Style pour le titre */
.stTitle {
font-family: 'Noto Kufi Arabic', sans-serif !important;
font-weight: 700 !important;
}
/* Style pour les boutons */
.stButton>button {
font-family: 'Noto Kufi Arabic', sans-serif !important;
font-weight: 500 !important;
}
/* Style pour les champs de texte */
.stTextInput>div>div>input {
font-family: 'Noto Kufi Arabic', sans-serif !important;
}
</style>
""", unsafe_allow_html=True)
    st.title("هذا تطبيق للاجابة عن الاسئلة المتعلقة بالقانون المغربي ")
    st.title("حاليا يضم 136 قانونا")
    qa_system = get_qa_system()
    question = st.text_input("اطرح سؤالك :", placeholder=None)
    if st.button("بحث"):
        if question:
            response_container = st.empty()
            try:
                query_engine = qa_system.create_query_engine()
                st.session_state.current_response = ""
                response = query_engine.query(question)
                # Stream tokens into the placeholder as they arrive.  The
                # original kept a second, unused accumulator (full_response)
                # plus a closure doing the same work; one accumulator suffices.
                for token in response.response_gen:
                    st.session_state.current_response += token
                    response_container.markdown(st.session_state.current_response,
                                                unsafe_allow_html=True)
                # Show the retrieved source nodes, if the engine exposed them.
                if hasattr(response, 'source_nodes'):
                    st.markdown("""
<div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
<div class="streamlit-expanderHeader">
المصادر
</div>
</div>
""", unsafe_allow_html=True)
                    with st.expander(""):
                        for node in response.source_nodes:
                            # node.text was sliced with [:] (a pointless full copy).
                            st.markdown(f"""
<div style="direction: rtl !important; text-align: right !important; font-family: 'Noto Kufi Arabic', sans-serif !important;">
<p style="text-align: right !important;">مصادر الجواب : {node.metadata.get('file_name', 'Unknown')}</p>
<p style="text-align: right !important;">Extrait: {node.text}</p>
</div>
""", unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Une erreur s'est produite : {e}")
# Standard script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()